Commit a9686f94 authored by Adam Straw's avatar Adam Straw Committed by Robert Kimball

adding i32 as quantized type (#2050)

* adding i32 as quantized type

* code format

* mask gpu unit tests

* unused variable

* intel gpu unit test manifest

* fix typo in unit test manifest
parent 2ebacf5e
......@@ -142,6 +142,41 @@ namespace ngraph
throw ngraph_error("Unsupported dequantization element type");
}
}
else if (args[0].get_element_type() == element::i32)
{
if (out[0].get_element_type() == element::f32)
{
functor = [&, arg0_shape, arg1_shape, daxes](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
ngraph::runtime::reference::dequantize<int32_t>(
static_cast<int32_t*>(arg0_tensor),
static_cast<float*>(arg1_tensor),
static_cast<int32_t*>(arg2_tensor),
static_cast<float*>(out_tensor),
arg0_shape,
arg1_shape,
daxes);
};
}
else if (out[0].get_element_type() == element::f64)
{
functor = [&, arg0_shape, arg1_shape, daxes](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
ngraph::runtime::reference::dequantize<int32_t>(
static_cast<int32_t*>(arg0_tensor),
static_cast<double*>(arg1_tensor),
static_cast<int32_t*>(arg2_tensor),
static_cast<double*>(out_tensor),
arg0_shape,
arg1_shape,
daxes);
};
}
else
{
throw ngraph_error("Unsupported dequantization element type");
}
}
else
{
throw ngraph_error("Unsupported input element type");
......@@ -235,6 +270,21 @@ namespace ngraph
round_mode);
};
}
else if (out[0].get_element_type() == element::i32)
{
functor = [&, arg0_shape, arg1_shape, daxes, round_mode](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
ngraph::runtime::reference::quantize<float>(
static_cast<float*>(arg0_tensor),
static_cast<float*>(arg1_tensor),
static_cast<int32_t*>(arg2_tensor),
static_cast<int32_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes,
round_mode);
};
}
else
{
throw ngraph_error("Unsupported quantization element type");
......@@ -272,6 +322,21 @@ namespace ngraph
round_mode);
};
}
else if (out[0].get_element_type() == element::i32)
{
functor = [&, arg0_shape, arg1_shape, daxes, round_mode](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
ngraph::runtime::reference::quantize<double>(
static_cast<double*>(arg0_tensor),
static_cast<double*>(arg1_tensor),
static_cast<int32_t*>(arg2_tensor),
static_cast<int32_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes,
round_mode);
};
}
else
{
throw ngraph_error("Unsupported quantization element type");
......
......@@ -785,6 +785,12 @@ namespace ngraph
if (offset[0] != 0)
return;
}
if (node->get_input_element_type(0) == element::i32)
{
auto offset = offset_const_op->get_vector<int32_t>();
if (offset[0] != 0)
return;
}
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
......@@ -818,6 +824,14 @@ namespace ngraph
return;
}
}
if (node->get_output_element_type(0) == element::i32)
{
auto offset = offset_const_op->get_vector<int32_t>();
if (offset[0] != 0)
{
return;
}
}
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
......
......@@ -16,6 +16,7 @@ shape_of_scalar
shape_of_vector
shape_of_matrix
shape_of_5d
quantize_clamp_int32
# this one just started failing
batchnorm_bprop_n4c3h2w2
......@@ -30,15 +30,25 @@ backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_avgpool_n1_c1_hw2x2
backwards_avgpool_n1_c1_hw4x4
backwards_avgpool_n2_c2_hw4x4
quantize
quantize_axes
quantize_int8
quantize_clamp
dequantize
dequantize_zero_offset
dequantize_axes
dequantize_int8
dequantize_zero_offset
dequantize_int8_zero_offset
dequantize_int32
dequantize_int32_zero_offset
quantize
quantize_zero_offset
quantize_axes
quantize_int8
quantize_int8_zero_offset
quantize_int32
quantize_int32_zero_offset
quantize_clamp_uint8
quantize_clamp_int8
quantize_clamp_int32
quantize_ROUND_NEAREST_TOWARD_ZERO
quantize_ROUND_NEAREST_TOWARD_INFINITY
quantize_ROUND_NEAREST_UPWARD
quantize_ROUND_NEAREST_DOWNWARD
quantize_ROUND_NEAREST_TOWARD_EVEN
......
......@@ -18,9 +18,12 @@ backwards_tanh
batch_norm_one_output
batch_norm_three_outputs
dequantize
dequantize_zero_offset
dequantize_axes
dequantize_int8
dequantize_zero_offset
dequantize_int8_zero_offset
dequantize_int32
dequantize_int32_zero_offset
divide_by_zero_int32
dot_3d_multi_axis
dot_4d_5d_multi_axis
......@@ -31,17 +34,24 @@ max_pool_3d
numeric_double_inf
numeric_double_nan
quantize
quantize_ROUND_DOWN
quantize_ROUND_NEAREST_DOWNWARD
quantize_ROUND_NEAREST_TOWARD_EVEN
quantize_zero_offset
quantize_axes
quantize_int8
quantize_int8_zero_offset
quantize_int32
quantize_int32_zero_offset
quantize_clamp_uint8
quantize_clamp_int8
quantize_clamp_int32
quantize_ROUND_NEAREST_TOWARD_ZERO
quantize_ROUND_NEAREST_TOWARD_INFINITY
quantize_ROUND_NEAREST_UPWARD
quantize_ROUND_NEAREST_DOWNWARD
quantize_ROUND_NEAREST_TOWARD_EVEN
quantize_ROUND_TOWARD_INFINITY
quantize_ROUND_TOWARD_ZERO
quantize_ROUND_UP
quantize_axes
quantize_clamp
quantize_int8
quantize_ROUND_DOWN
reduce_window_emulating_max_pool_1d_1channel_1image
reduce_window_emulating_max_pool_1d_1channel_2image
reduce_window_emulating_max_pool_1d_2channel_2image
......
......@@ -971,6 +971,17 @@ private:
quantize->get_axes(),
quantize->get_round_mode());
}
else if (type == element::i32)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const int32_t*>(args[2]),
static_cast<int32_t*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
quantize->get_round_mode());
}
else
{
std::stringstream ss;
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment