Commit 94d39716 authored by Nishant Patel, committed by Scott Cyphers

Generic Reference Convolution (#2840)

* Generalize types in general convolution

* typo

* rounding

* Do prod wide

* templatize conv in cpu/kernel & add u8u8 support for Qconv

* Remove cast function

* Avoid compiler warning

* Merge problem
parent 8e798add
...@@ -49,6 +49,11 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc ...@@ -49,6 +49,11 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc
auto output_et = requantize ? element::i8 : element::i32; auto output_et = requantize ? element::i8 : element::i32;
if (data_batch->get_element_type() == element::u8 && filters->get_element_type() == element::u8)
{
output_et = element::u8;
}
set_output_type(0, set_output_type(0,
output_et, output_et,
util::infer_convolution_output_shape(this, util::infer_convolution_output_shape(this,
......
...@@ -88,10 +88,10 @@ namespace ngraph ...@@ -88,10 +88,10 @@ namespace ngraph
} }
else else
{ {
std::function<decltype(runtime::cpu::kernel::convolution<float>)> kernel; std::function<decltype(runtime::cpu::kernel::convolution<float, float, float>)>
kernel;
SELECT_KERNEL( kernel = runtime::cpu::kernel::convolution<float, float, float>;
kernel, out[0].get_element_type(), runtime::cpu::kernel::convolution);
auto window_movement_strides = convolution->get_window_movement_strides(); auto window_movement_strides = convolution->get_window_movement_strides();
auto window_dilation_strides = convolution->get_window_dilation_strides(); auto window_dilation_strides = convolution->get_window_dilation_strides();
...@@ -123,7 +123,8 @@ namespace ngraph ...@@ -123,7 +123,8 @@ namespace ngraph
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides); data_dilation_strides,
1.0f);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "ngraph/op/experimental/quantized_conv_relu.hpp" #include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp" #include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/kernel/convolution.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp" #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
...@@ -35,19 +36,24 @@ namespace ngraph ...@@ -35,19 +36,24 @@ namespace ngraph
template <> template <>
void Builder::BUILDER_DECL(ngraph::op::QuantizedConvolution) void Builder::BUILDER_DECL(ngraph::op::QuantizedConvolution)
{ {
auto qconvolution = static_cast<const ngraph::op::QuantizedConvolution*>(node);
auto& functors = external_function->get_functors();
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto arg2_buffer_index = external_function->get_buffer_index(args[2].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto scales_size = shape_size(args[2].get_shape());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& functors = external_function->get_functors();
auto arg0_buffer_index =
external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index =
external_function->get_buffer_index(args[1].get_name());
auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto scales_size = shape_size(args[2].get_shape());
auto conv_desc = auto conv_desc =
mkldnn_emitter mkldnn_emitter
...@@ -101,7 +107,51 @@ namespace ngraph ...@@ -101,7 +107,51 @@ namespace ngraph
} }
else else
{ {
throw ngraph_error("unsupported parameters for QuantizedConvolution via DEX"); std::function<decltype(
runtime::cpu::kernel::convolution<uint8_t, uint8_t, uint8_t, int32_t>)>
kernel;
kernel = runtime::cpu::kernel::convolution<uint8_t, uint8_t, uint8_t, int32_t>;
auto window_movement_strides = qconvolution->get_window_movement_strides();
auto window_dilation_strides = qconvolution->get_window_dilation_strides();
auto padding_below = qconvolution->get_padding_below();
auto padding_above = qconvolution->get_padding_above();
auto data_dilation_strides = qconvolution->get_data_dilation_strides();
auto functor = [&,
kernel,
arg0_shape,
arg1_shape,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
out0_buffer_index,
result_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
scales_size](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
vector<float> dyn_scales;
dyn_scales.assign(static_cast<float*>(ctx->buffer_data[arg2_buffer_index]),
static_cast<float*>(ctx->buffer_data[arg2_buffer_index]) +
scales_size);
kernel(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[out0_buffer_index],
arg0_shape,
arg1_shape,
result_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
dyn_scales[0]);
};
functors.emplace_back(functor);
} }
} }
......
...@@ -27,7 +27,11 @@ namespace ngraph ...@@ -27,7 +27,11 @@ namespace ngraph
{ {
namespace kernel namespace kernel
{ {
template <typename ElementType> template <typename INPUT,
typename FILTER,
typename OUTPUT,
typename ACCUMULATION =
typename ngraph::runtime::reference::widen<OUTPUT>::type>
void convolution(void* input0, void convolution(void* input0,
void* input1, void* input1,
void* output, void* output,
...@@ -38,19 +42,22 @@ namespace ngraph ...@@ -38,19 +42,22 @@ namespace ngraph
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides) const Strides& data_dilation_strides,
const float requant_scale)
{ {
reference::convolution<ElementType>(static_cast<const ElementType*>(input0), reference::convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(
static_cast<const ElementType*>(input1), static_cast<const INPUT*>(input0),
static_cast<ElementType*>(output), static_cast<const FILTER*>(input1),
arg0_shape, static_cast<OUTPUT*>(output),
arg1_shape, arg0_shape,
result_shape, arg1_shape,
window_movement_strides, result_shape,
window_dilation_strides, window_movement_strides,
padding_below, window_dilation_strides,
padding_above, padding_below,
data_dilation_strides); padding_above,
data_dilation_strides,
requant_scale);
} }
template <typename ElementType> template <typename ElementType>
......
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
#pragma once #pragma once
#include <cfenv>
#include <cmath> #include <cmath>
#include <functional>
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/coordinate_transform.hpp" #include "ngraph/coordinate_transform.hpp"
...@@ -29,13 +31,34 @@ namespace ngraph ...@@ -29,13 +31,34 @@ namespace ngraph
{ {
namespace reference namespace reference
{ {
// Type-level map from an element type to a (possibly) wider type. In this
// file it supplies the default ACCUMULATION type of the convolution
// templates. The primary template is the identity: any type without a
// specialization below maps to itself.
template <typename T>
struct widen { using type = T; };

// float maps to double.
template <>
struct widen<float> { using type = double; };

// double maps to long double.
template <>
struct widen<double> { using type = long double; };
// in: NC_I... // in: NC_I...
// filter: C_OC_I... // filter: C_OC_I...
// out: NC_O... // out: NC_O...
template <typename T> template <typename INPUT,
void general_convolution(const T* in, typename FILTER,
const T* filter, typename OUTPUT,
T* out, typename ACCUMULATION = typename widen<OUTPUT>::type>
void general_convolution(const INPUT* in,
const FILTER* filter,
OUTPUT* out,
const Shape& in_shape, const Shape& in_shape,
const Shape& filter_shape, const Shape& filter_shape,
const Shape& out_shape, const Shape& out_shape,
...@@ -49,8 +72,11 @@ namespace ngraph ...@@ -49,8 +72,11 @@ namespace ngraph
size_t filter_out_channel_axis, size_t filter_out_channel_axis,
size_t filter_in_channel_axis, size_t filter_in_channel_axis,
size_t out_batch_axis, size_t out_batch_axis,
size_t out_channel_axis) size_t out_channel_axis,
const float requant_scale = 1.0f)
{ {
auto old_mode = std::fegetround();
std::fesetround(FE_TONEAREST);
// Comments throughout assume without loss of generality that: // Comments throughout assume without loss of generality that:
// //
// * batch axes for both in and out are 0 // * batch axes for both in and out are 0
...@@ -164,7 +190,7 @@ namespace ngraph ...@@ -164,7 +190,7 @@ namespace ngraph
// //
// out[O] += in[I] * filter[F]. // out[O] += in[I] * filter[F].
T result = 0; ACCUMULATION result = 0;
CoordinateTransform::Iterator in_it = in_transform.begin(); CoordinateTransform::Iterator in_it = in_transform.begin();
CoordinateTransform::Iterator filter_it = filter_transform.begin(); CoordinateTransform::Iterator filter_it = filter_transform.begin();
...@@ -185,8 +211,8 @@ namespace ngraph ...@@ -185,8 +211,8 @@ namespace ngraph
size_t filter_idx = filter_transform.index(filter_coord); size_t filter_idx = filter_transform.index(filter_coord);
for (size_t in_channel = 0; in_channel < n_in_channels; ++in_channel) for (size_t in_channel = 0; in_channel < n_in_channels; ++in_channel)
{ {
T in_v = in[in_idx]; ACCUMULATION in_v = in[in_idx];
T f_v = filter[filter_idx]; ACCUMULATION f_v = filter[filter_idx];
result += in_v * f_v; result += in_v * f_v;
in_idx += in_channel_stride; in_idx += in_channel_stride;
filter_idx += filter_in_channel_stride; filter_idx += filter_in_channel_stride;
...@@ -195,15 +221,19 @@ namespace ngraph ...@@ -195,15 +221,19 @@ namespace ngraph
++in_it; ++in_it;
++filter_it; ++filter_it;
} }
out[out_transform.index(out_coord)] =
out[out_transform.index(out_coord)] = result; static_cast<OUTPUT>(result * requant_scale);
} }
std::fesetround(old_mode);
} }
template <typename T> template <typename INPUT,
void convolution(const T* in, typename FILTER,
const T* filter, typename OUTPUT,
T* out, typename ACCUMULATION = typename widen<OUTPUT>::type>
void convolution(const INPUT* in,
const FILTER* filter,
OUTPUT* out,
const Shape& in_shape, const Shape& in_shape,
const Shape& filter_shape, const Shape& filter_shape,
const Shape& out_shape, const Shape& out_shape,
...@@ -211,31 +241,37 @@ namespace ngraph ...@@ -211,31 +241,37 @@ namespace ngraph
const Strides& filter_dilation, const Strides& filter_dilation,
const CoordinateDiff& in_pad_below, const CoordinateDiff& in_pad_below,
const CoordinateDiff& in_pad_above, const CoordinateDiff& in_pad_above,
const Strides& in_dilation) const Strides& in_dilation,
const float requant_scale = 1.0f)
{ {
general_convolution(in, general_convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(in,
filter, filter,
out, out,
in_shape, in_shape,
filter_shape, filter_shape,
out_shape, out_shape,
stride, stride,
filter_dilation, filter_dilation,
in_pad_below, in_pad_below,
in_pad_above, in_pad_above,
in_dilation, in_dilation,
0, 0,
1, 1,
0, 0,
1, 1,
0, 0,
1); 1,
requant_scale);
} }
template <typename T> template <typename INPUT,
void convolution_backprop_filter(const T* in, typename OUTPUT,
const T* delta_out, typename FILTER,
T* delta_filter, typename ACCUMULATION = typename widen<FILTER>::type>
void convolution_backprop_filter(const INPUT* in,
const OUTPUT* delta_out,
FILTER* delta_filter,
const Shape& in_shape, const Shape& in_shape,
const Shape& out_shape, const Shape& out_shape,
const Shape& filter_shape, const Shape& filter_shape,
...@@ -245,29 +281,32 @@ namespace ngraph ...@@ -245,29 +281,32 @@ namespace ngraph
const CoordinateDiff& backprop_in_pad_above, const CoordinateDiff& backprop_in_pad_above,
const Strides& in_dilation) const Strides& in_dilation)
{ {
general_convolution(in, general_convolution<INPUT, OUTPUT, FILTER, ACCUMULATION>(in,
delta_out, delta_out,
delta_filter, delta_filter,
in_shape, in_shape,
out_shape, out_shape,
filter_shape, filter_shape,
filter_dilation, filter_dilation,
stride, stride,
in_pad_below, in_pad_below,
backprop_in_pad_above, backprop_in_pad_above,
in_dilation, in_dilation,
1, 1,
0, 0,
1, 1,
0, 0,
1, 1,
0); 0);
} }
template <typename T> template <typename OUTPUT,
void convolution_backprop_in(const T* delta_out, typename FILTER,
const T* filter, typename INPUT,
T* delta_in, typename ACCUMULATION = typename widen<INPUT>::type>
void convolution_backprop_in(const OUTPUT* delta_out,
const FILTER* filter,
INPUT* delta_in,
const Shape& out_shape, const Shape& out_shape,
const Shape& filter_shape, const Shape& filter_shape,
const Shape& in_shape, const Shape& in_shape,
...@@ -279,31 +318,32 @@ namespace ngraph ...@@ -279,31 +318,32 @@ namespace ngraph
{ {
// Note that we only reverse the spatial dimensions here (loop // Note that we only reverse the spatial dimensions here (loop
// starts at 2) // starts at 2)
std::vector<T> reversed(shape_size(filter_shape)); std::vector<INPUT> reversed(shape_size(filter_shape));
AxisSet reverse_axes; AxisSet reverse_axes;
for (size_t i = 2; i < filter_shape.size(); ++i) for (size_t i = 2; i < filter_shape.size(); ++i)
{ {
reverse_axes.insert(i); reverse_axes.insert(i);
} }
reverse<T>(filter, &reversed[0], filter_shape, filter_shape, reverse_axes); reverse<FILTER>(filter, &reversed[0], filter_shape, filter_shape, reverse_axes);
general_convolution(delta_out, general_convolution<OUTPUT, FILTER, INPUT, ACCUMULATION>(
&reversed[0], delta_out,
delta_in, &reversed[0],
out_shape, delta_in,
filter_shape, out_shape,
in_shape, filter_shape,
in_dilation, in_shape,
filter_dilation, in_dilation,
backward_delta_out_pad_below, filter_dilation,
backward_delta_out_pad_above, backward_delta_out_pad_below,
stride, backward_delta_out_pad_above,
0, stride,
1, 0,
1, 1,
0, 1,
0, 0,
1); 0,
1);
} }
} // namespace reference } // namespace reference
} // namespace runtime } // namespace runtime
......
...@@ -1423,3 +1423,49 @@ TEST(builder, dynamic_scaled_QD_with_bias) ...@@ -1423,3 +1423,49 @@ TEST(builder, dynamic_scaled_QD_with_bias)
EXPECT_EQ((vector<uint8_t>{178, 231, 255, 255, 0, 255, 255, 255, 255, 255, 0, 255}), EXPECT_EQ((vector<uint8_t>{178, 231, 255, 255, 0, 255, 255, 255, 255, 255, 0, 255}),
read_vector<uint8_t>(f_requantize_relu_r)); read_vector<uint8_t>(f_requantize_relu_r));
} }
// QuantizedConvolution with u8 data and u8 filters, run on the CPU backend.
// A 1x1x3x4 input is convolved with a 1x1x3x3 filter using unit strides,
// unit dilations and a padding of 1 on every side, with a scalar f32 scale
// of 2. Each expected element is written as <plain convolution sum> * 2.
TEST(builder, scaled_QC_u8u8)
{
    const Shape data_shape{1, 1, 3, 4};
    const Shape filter_shape{1, 1, 3, 3};
    const Shape output_shape{1, 1, 3, 4};

    const vector<uint8_t> data_values = {1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4};
    const vector<uint8_t> filter_values = {1, 2, 3, 4, 5, 0, 0, 1, 2};

    // Build the graph: two u8 parameters feeding one QuantizedConvolution.
    auto data_param = make_shared<op::Parameter>(element::u8, data_shape);
    auto filter_param = make_shared<op::Parameter>(element::u8, filter_shape);
    auto scale = op::Constant::create(element::f32, Shape{}, {2});

    const Strides move_strides{1, 1};
    const Strides filter_dilation{1, 1};
    const CoordinateDiff below_pads{1, 1};
    const CoordinateDiff above_pads{1, 1};
    const Strides data_dilation{1, 1};
    auto qconv = make_shared<ngraph::op::QuantizedConvolution>(data_param,
                                                               filter_param,
                                                               move_strides,
                                                               filter_dilation,
                                                               below_pads,
                                                               above_pads,
                                                               data_dilation,
                                                               scale,
                                                               false);
    auto f = make_shared<Function>(NodeVector{qconv}, ParameterVector{data_param, filter_param});
    constant_fold(f);

    auto backend = runtime::Backend::create("CPU");

    // Input and output tensors.
    auto data_tensor = backend->create_tensor(element::u8, data_shape);
    copy_data(data_tensor, data_values);
    auto filter_tensor = backend->create_tensor(element::u8, filter_shape);
    copy_data(filter_tensor, filter_values);
    auto result = backend->create_tensor(element::u8, output_shape);

    auto handle = backend->compile(f);
    handle->call_with_validate({result}, {data_tensor, filter_tensor});

    const vector<uint8_t> expected{22 * 2,
                                   34 * 2,
                                   30 * 2,
                                   32 * 2,
                                   38 * 2,
                                   72 * 2,
                                   90 * 2,
                                   43 * 2,
                                   33 * 2,
                                   52 * 2,
                                   43 * 2,
                                   39 * 2};
    EXPECT_EQ(expected, read_vector<uint8_t>(result));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment