Commit ecb1608b authored by gaurides's avatar gaurides Committed by Scott Cyphers

Fix data type for conv builder (#2885)

* Fix data type for conv builder

* Added support for other data types
parent 02c56a01
......@@ -91,7 +91,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::convolution<float, float, float>)>
kernel;
kernel = runtime::cpu::kernel::convolution<float, float, float>;
SELECT_KERNEL_3ARGS(
kernel, out[0].get_element_type(), runtime::cpu::kernel::convolution);
auto window_movement_strides = convolution->get_window_movement_strides();
auto window_dilation_strides = convolution->get_window_dilation_strides();
......
......@@ -76,6 +76,52 @@
KV = K<uint64_t>; \
}
#define SELECT_KERNEL_3ARGS(KV, ET, K) \
if (ET == element::boolean) \
{ \
KV = K<char, char, char>; \
} \
else if (ET == element::f32) \
{ \
KV = K<float, float, float>; \
} \
else if (ET == element::f64) \
{ \
KV = K<double, double, double>; \
} \
else if (ET == element::i8) \
{ \
KV = K<int8_t, int8_t, int8_t>; \
} \
else if (ET == element::i16) \
{ \
KV = K<int16_t, int16_t, int16_t>; \
} \
else if (ET == element::i32) \
{ \
KV = K<int32_t, int32_t, int32_t>; \
} \
else if (ET == element::i64) \
{ \
KV = K<int64_t, int64_t, int64_t>; \
} \
else if (ET == element::u8) \
{ \
KV = K<uint8_t, uint8_t, uint8_t>; \
} \
else if (ET == element::u16) \
{ \
KV = K<uint16_t, uint16_t, uint16_t>; \
} \
else if (ET == element::u32) \
{ \
KV = K<uint32_t, uint32_t, uint32_t>; \
} \
else if (ET == element::u64) \
{ \
KV = K<uint64_t, uint64_t, uint64_t>; \
}
#define SELECT_RANK(KV, ET, R, K) \
if (R == 1) \
KV = K<ET, 1>; \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment