Unverified Commit 8c4ae5ea authored by Adam Procter's avatar Adam Procter Committed by GitHub

Zero padding for convolution (#352)

parent 06f9efd9
...@@ -29,12 +29,16 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -29,12 +29,16 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
const Coordinate& source_start_corner, const Coordinate& source_start_corner,
const Coordinate& source_end_corner, const Coordinate& source_end_corner,
const Strides& source_strides, const Strides& source_strides,
const AxisVector& source_axis_order) const AxisVector& source_axis_order,
const Shape& source_padding_below,
const Shape& source_padding_above)
: m_source_shape(source_shape) : m_source_shape(source_shape)
, m_source_start_corner(source_start_corner) , m_source_start_corner(source_start_corner)
, m_source_end_corner(source_end_corner) , m_source_end_corner(source_end_corner)
, m_source_strides(source_strides) , m_source_strides(source_strides)
, m_source_axis_order(source_axis_order) , m_source_axis_order(source_axis_order)
, m_source_padding_below(source_padding_below)
, m_source_padding_above(source_padding_above)
{ {
m_n_axes = source_shape.size(); m_n_axes = source_shape.size();
...@@ -61,6 +65,16 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -61,6 +65,16 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
throw std::domain_error( throw std::domain_error(
"Source axis order does not have the same number of axes as the source space shape"); "Source axis order does not have the same number of axes as the source space shape");
} }
if (m_n_axes != source_padding_below.size())
{
throw std::domain_error(
"Padding-below shape does not have the same number of axes as the source space shape");
}
if (m_n_axes != source_padding_above.size())
{
throw std::domain_error(
"Padding-above shape does not have the same number of axes as the source space shape");
}
AxisVector all_axes(m_n_axes); AxisVector all_axes(m_n_axes);
size_t n = 0; size_t n = 0;
...@@ -75,7 +89,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -75,7 +89,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
for (size_t i = 0; i < m_n_axes; i++) for (size_t i = 0; i < m_n_axes; i++)
{ {
if (source_start_corner[i] >= source_shape[i] && if (source_start_corner[i] >=
source_shape[i] + source_padding_below[i] + source_padding_above[i] &&
!(source_start_corner[i] == 0 && source_shape[i] == 0)) !(source_start_corner[i] == 0 && source_shape[i] == 0))
{ {
std::stringstream ss; std::stringstream ss;
...@@ -87,7 +102,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -87,7 +102,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
for (size_t i = 0; i < m_n_axes; i++) for (size_t i = 0; i < m_n_axes; i++)
{ {
if (source_end_corner[i] > source_shape[i]) if (source_end_corner[i] >
source_shape[i] + source_padding_below[i] + source_padding_above[i])
{ {
std::stringstream ss; std::stringstream ss;
...@@ -115,7 +131,27 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -115,7 +131,27 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
} }
} }
static AxisVector default_axis_order(size_t n_axes) Shape CoordinateTransform::default_padding(size_t n_axes)
{
return Shape(n_axes, 0);
}
CoordinateTransform::CoordinateTransform(const Shape& source_shape,
const Coordinate& source_start_corner,
const Coordinate& source_end_corner,
const Strides& source_strides,
const AxisVector& source_axis_order)
: CoordinateTransform(source_shape,
source_start_corner,
source_end_corner,
source_strides,
source_axis_order,
default_padding(source_shape.size()),
default_padding(source_shape.size()))
{
}
AxisVector CoordinateTransform::default_axis_order(size_t n_axes)
{ {
AxisVector result(n_axes); AxisVector result(n_axes);
size_t n = 0; size_t n = 0;
...@@ -132,11 +168,13 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -132,11 +168,13 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
source_start_corner, source_start_corner,
source_end_corner, source_end_corner,
source_strides, source_strides,
default_axis_order(source_shape.size())) default_axis_order(source_shape.size()),
default_padding(source_shape.size()),
default_padding(source_shape.size()))
{ {
} }
static Strides default_source_strides(size_t n_axes) Strides CoordinateTransform::default_source_strides(size_t n_axes)
{ {
return AxisVector(n_axes, 1); return AxisVector(n_axes, 1);
} }
...@@ -148,16 +186,18 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -148,16 +186,18 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
source_start_corner, source_start_corner,
source_end_corner, source_end_corner,
default_source_strides(source_shape.size()), default_source_strides(source_shape.size()),
default_axis_order(source_shape.size())) default_axis_order(source_shape.size()),
default_padding(source_shape.size()),
default_padding(source_shape.size()))
{ {
} }
static Coordinate default_source_start_corner(size_t n_axes) Coordinate CoordinateTransform::default_source_start_corner(size_t n_axes)
{ {
return Coordinate(n_axes, 0); return Coordinate(n_axes, 0);
} }
static Coordinate default_source_end_corner(const Shape& source_shape) Coordinate CoordinateTransform::default_source_end_corner(const Shape& source_shape)
{ {
return source_shape; return source_shape;
} }
...@@ -167,7 +207,9 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape) ...@@ -167,7 +207,9 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape)
default_source_start_corner(source_shape.size()), default_source_start_corner(source_shape.size()),
default_source_end_corner(source_shape), default_source_end_corner(source_shape),
default_source_strides(source_shape.size()), default_source_strides(source_shape.size()),
default_axis_order(source_shape.size())) default_axis_order(source_shape.size()),
default_padding(source_shape.size()),
default_padding(source_shape.size()))
{ {
} }
...@@ -204,8 +246,9 @@ Coordinate CoordinateTransform::to_source_coordinate(const Coordinate& c) const ...@@ -204,8 +246,9 @@ Coordinate CoordinateTransform::to_source_coordinate(const Coordinate& c) const
for (size_t axis = 0; axis < m_n_axes; axis++) for (size_t axis = 0; axis < m_n_axes; axis++)
{ {
result[m_source_axis_order[axis]] = result[m_source_axis_order[axis]] = c[axis] * m_source_strides[axis] +
c[axis] * m_source_strides[axis] + m_source_start_corner[axis]; m_source_start_corner[axis] -
m_source_padding_below[axis];
} }
return result; return result;
...@@ -230,6 +273,28 @@ bool CoordinateTransform::in_bounds(const Coordinate& c) const ...@@ -230,6 +273,28 @@ bool CoordinateTransform::in_bounds(const Coordinate& c) const
return true; return true;
} }
// Check if a coordinate corresponds to one of the padding elements that has been added to
// the source space.
bool CoordinateTransform::in_padding(const Coordinate& c) const
{
if (c.size() != m_n_axes)
{
throw std::domain_error("Coordinate rank does not match the coordinate transform rank");
}
for (size_t axis = 0; axis < m_n_axes; axis++)
{
size_t padded_pos = c[axis] * m_source_strides[axis] + m_source_start_corner[axis];
if (padded_pos < m_source_padding_below[axis] ||
padded_pos >= m_source_padding_below[axis] + m_source_shape[axis])
{
return true;
}
}
return false;
}
Coordinate CoordinateTransform::get_target_shape() const Coordinate CoordinateTransform::get_target_shape() const
{ {
return m_target_shape; return m_target_shape;
......
...@@ -26,6 +26,14 @@ namespace ngraph ...@@ -26,6 +26,14 @@ namespace ngraph
class CoordinateTransform class CoordinateTransform
{ {
public: public:
CoordinateTransform(const Shape& source_shape,
const Coordinate& source_start_corner,
const Coordinate& source_end_corner,
const Strides& source_strides,
const AxisVector& source_axis_order,
const Shape& source_padding_below,
const Shape& source_padding_above);
CoordinateTransform(const Shape& source_shape, CoordinateTransform(const Shape& source_shape,
const Coordinate& source_start_corner, const Coordinate& source_start_corner,
const Coordinate& source_end_corner, const Coordinate& source_end_corner,
...@@ -45,6 +53,8 @@ namespace ngraph ...@@ -45,6 +53,8 @@ namespace ngraph
size_t index(const Coordinate& c) const; size_t index(const Coordinate& c) const;
bool in_bounds(const Coordinate& c) const; bool in_bounds(const Coordinate& c) const;
bool in_padding(const Coordinate& c) const;
Coordinate to_source_coordinate(const Coordinate& c) const;
Coordinate get_target_shape() const; Coordinate get_target_shape() const;
Shape get_source_shape() { return m_source_shape; } Shape get_source_shape() { return m_source_shape; }
...@@ -75,14 +85,20 @@ namespace ngraph ...@@ -75,14 +85,20 @@ namespace ngraph
Iterator begin() noexcept { return Iterator(m_target_shape); } Iterator begin() noexcept { return Iterator(m_target_shape); }
Iterator end() noexcept { return Iterator(m_target_shape, true); } Iterator end() noexcept { return Iterator(m_target_shape, true); }
private: private:
Coordinate to_source_coordinate(const Coordinate& c) const;
size_t index_source(const Coordinate& c) const; size_t index_source(const Coordinate& c) const;
static Shape default_padding(size_t n_axes);
static AxisVector default_axis_order(size_t n_axes);
static Strides default_source_strides(size_t n_axes);
static Coordinate default_source_start_corner(size_t n_axes);
static Coordinate default_source_end_corner(const Shape& source_shape);
Shape m_source_shape; Shape m_source_shape;
Shape m_source_start_corner; Shape m_source_start_corner;
Shape m_source_end_corner; Shape m_source_end_corner;
Strides m_source_strides; Strides m_source_strides;
AxisVector m_source_axis_order; AxisVector m_source_axis_order;
Shape m_source_padding_below;
Shape m_source_padding_above;
Shape m_target_shape; Shape m_target_shape;
size_t m_n_axes; size_t m_n_axes;
......
...@@ -21,10 +21,14 @@ using namespace ngraph; ...@@ -21,10 +21,14 @@ using namespace ngraph;
op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides) const Strides& window_dilation_strides,
const Shape& padding_below,
const Shape& padding_above)
: RequiresTensorViewArgs("Convolution", {image_batch, filters}) : RequiresTensorViewArgs("Convolution", {image_batch, filters})
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
, m_padding_above(padding_above)
{ {
auto& image_batch_shape = get_inputs().at(0).get_shape(); auto& image_batch_shape = get_inputs().at(0).get_shape();
auto& filters_shape = get_inputs().at(1).get_shape(); auto& filters_shape = get_inputs().at(1).get_shape();
...@@ -88,15 +92,32 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -88,15 +92,32 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
} }
// //
// Extract input image shape Di and make sure all dimensions are larger than 0. // Make sure padding-below and padding-above shapes have same rank as Di.
//
if (m_padding_below.size() != m_image_dimension_count)
{
throw ngraph_error(
"Convolution padding-below rank does not match number of image dimensions.");
}
if (m_padding_above.size() != m_image_dimension_count)
{
throw ngraph_error(
"Convolution padding-above rank does not match number of image dimensions.");
}
//
// Extract input image shape Di and make sure all dimensions are larger than 0 after padding.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) for (size_t i = 0; i < m_image_dimension_count; i++)
{ {
m_input_image_shape.push_back(image_batch_shape[1 + 1 + +i]); m_input_image_shape.push_back(image_batch_shape[1 + 1 + i]);
m_padded_input_image_shape.push_back(padding_below[i] + image_batch_shape[1 + 1 + i] +
padding_above[i]);
if (m_input_image_shape[i] == 0) if (m_padded_input_image_shape[i] == 0)
{ {
throw ngraph_error("Convolution input image dimension is zero."); throw ngraph_error("Convolution input image dimension is zero even with padding.");
} }
} }
...@@ -127,9 +148,10 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -127,9 +148,10 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
m_window_physical_shape.push_back( m_window_physical_shape.push_back(
(m_window_virtual_shape[i] - 1) * m_window_dilation_strides[i] + 1); (m_window_virtual_shape[i] - 1) * m_window_dilation_strides[i] + 1);
if (m_window_physical_shape[i] > m_input_image_shape[i]) if (m_window_physical_shape[i] > m_padded_input_image_shape[i])
{ {
throw ngraph_error("Convolution window after dilation is larger than the image."); throw ngraph_error(
"Convolution window after dilation is larger than the image even with padding.");
} }
} }
...@@ -142,8 +164,9 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -142,8 +164,9 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
{ {
throw ngraph_error("Convolution window axis movement stride is zero."); throw ngraph_error("Convolution window axis movement stride is zero.");
} }
m_output_image_shape.push_back(ceil_div( m_output_image_shape.push_back(
m_input_image_shape[i] - m_window_physical_shape[i] + 1, m_window_movement_strides[i])); ceil_div(m_padded_input_image_shape[i] - m_window_physical_shape[i] + 1,
m_window_movement_strides[i]));
} }
// //
...@@ -157,7 +180,33 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -157,7 +180,33 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape); set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape);
} }
Strides default_strides(const std::shared_ptr<Node>& image_batch) Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch)
{
auto& image_batch_shape = image_batch->get_shape();
if (image_batch_shape.size() < 3)
{
// For consistency we should throw the same error message here that we throw in the constructor.
throw ngraph_error(
"Convolution image batch input must have rank of at least 3 (one batch axis, one "
"input-channel axis, at least one image dimension).");
}
return Shape(image_batch_shape.size() - 2, 0);
}
op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides)
: Convolution(image_batch,
filters,
window_movement_strides,
window_dilation_strides,
default_padding(image_batch),
default_padding(image_batch))
{
}
Strides op::Convolution::default_strides(const std::shared_ptr<Node>& image_batch)
{ {
auto& image_batch_shape = image_batch->get_shape(); auto& image_batch_shape = image_batch->get_shape();
if (image_batch_shape.size() < 3) if (image_batch_shape.size() < 3)
...@@ -173,13 +222,23 @@ Strides default_strides(const std::shared_ptr<Node>& image_batch) ...@@ -173,13 +222,23 @@ Strides default_strides(const std::shared_ptr<Node>& image_batch)
op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides) const Strides& window_movement_strides)
: Convolution(image_batch, filters, window_movement_strides, default_strides(image_batch)) : Convolution(image_batch,
filters,
window_movement_strides,
default_strides(image_batch),
default_padding(image_batch),
default_padding(image_batch))
{ {
} }
op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters) const std::shared_ptr<Node>& filters)
: Convolution(image_batch, filters, default_strides(image_batch), default_strides(image_batch)) : Convolution(image_batch,
filters,
default_strides(image_batch),
default_strides(image_batch),
default_padding(image_batch),
default_padding(image_batch))
{ {
} }
...@@ -190,8 +249,12 @@ std::shared_ptr<Node> ...@@ -190,8 +249,12 @@ std::shared_ptr<Node>
{ {
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
} }
return std::make_shared<Convolution>( return std::make_shared<Convolution>(new_args.at(0),
new_args.at(0), new_args.at(1), m_window_movement_strides, m_window_dilation_strides); new_args.at(1),
m_window_movement_strides,
m_window_dilation_strides,
m_padding_below,
m_padding_above);
} }
/* /*
......
...@@ -30,20 +30,25 @@ namespace ngraph ...@@ -30,20 +30,25 @@ namespace ngraph
/// (sometimes called features) and \f$(d^f_1,\dots,d^f_n)\f$ are the filter dimensions. It is required that for all \f$i\f$, \f$0 < l_i(d^f_i - 1) + 1 \le d_i\f$. /// (sometimes called features) and \f$(d^f_1,\dots,d^f_n)\f$ are the filter dimensions. It is required that for all \f$i\f$, \f$0 < l_i(d^f_i - 1) + 1 \le d_i\f$.
/// (See below for the definition of the dilation \f$l_i\f$); /// (See below for the definition of the dilation \f$l_i\f$);
/// ///
/// and two optional parameters: /// and four optional parameters:
/// ///
/// 3. <i>(the window movement strides)</i> a vector of positive integers \f$(s_1,\dots,s_n)\f$, and /// 3. <i>(the window movement strides)</i> a vector of positive integers \f$(s_1,\dots,s_n)\f$ (default is all ones),
/// 4. <i>(the window dilation strides)</i> a vector of positive integers \f$(l_1,\dots,l_n)\f$. /// 4. <i>(the window dilation strides)</i> a vector of positive integers \f$(l_1,\dots,l_n)\f$ (default is all ones),
/// 5. <i>(the padding below)</i> a vector of non-negative integers \f$(p_1,\dots,p_n)\f$ (default is all zeros), and
/// 6. <i>(the padding above)</i> a vector of non-negative integers \f$(q_1,\dots,q_n)\f$ (default is all zeros).
/// ///
/// Define the <i>physical window size</i> as the vector \f$(p_1,\dots,p_n)\f$ where \f$p_i = l_i(d^f_i - 1) + 1\f$. /// The output has the shape \f$(N,C_\textit{out},d'_1,\dots,d'_n)\f$, where \f$d'_i = \lceil \frac{d_i + p_i + q_i - l_i(d^f_i - 1)}{s_i} \rceil\f$.
/// ///
/// The output has the shape \f$(N,C_\textit{out},d'_1,\dots,d'_n)\f$, where \f$d'_n = \lceil \frac{d_i - p_i + 1}{s_i} \rceil\f$. /// Given an input image batch tensor \f$T_\textit{in}\f$, first define the <i>padded input tensor</i> \f$T_\textit{pad}\f$, with shape \f$(N,C_\textit{in},d_1+p_1+q_1,\dots,d_n+p_n+q_n)\f$, as follows:
/// ///
/// Given an input image batch tensor \f$T_\textit{in}\f$ and an input filter tensor \f$T_\textit{filt}\f$, the output tensor is defined by the equation (TODO: I'm sure /// \f[
/// I messed something up here) /// T_\textit{pad}[a,c,i_1,\dots,i_n] = T[a,c,i_1 - p_1,\dots,i_n - p_n] \text{ if for all }k, p_k \le i_k \lt p_k + d_k, \text{ else } 0
/// \f]
///
/// Then, given an input filter tensor \f$T_\textit{filt}\f$, the output tensor \f$T_\textit{out}\f$ is defined by the equation:
/// ///
/// \f[ /// \f[
/// T_\textit{out}[a,c_\textit{out},i_1,\dots,i_n] = \sum_{c_\textit{in}=0,j_1=0,\dots,j_n=0}^{c_\textit{in}=C_\textit{in}-1,j_1=d^f_1-1,\dots,j_n=d^f_n-1} (T_\textit{filt}[c_\textit{out},c_\textit{in},j_1,\dots,j_n] \cdot T_\textit{in}[a,c_\textit{in},s_1i_1+l_1j_1,\dots,s_ni_n+l_nj_n]) /// T_\textit{out}[a,c_\textit{out},i_1,\dots,i_n] = \sum_{c_\textit{in}=0,j_1=0,\dots,j_n=0}^{c_\textit{in}=C_\textit{in}-1,j_1=d^f_1-1,\dots,j_n=d^f_n-1} (T_\textit{filt}[c_\textit{out},c_\textit{in},j_1,\dots,j_n] \cdot T_\textit{pad}[a,c_\textit{in},s_1i_1+l_1j_1,\dots,s_ni_n+l_nj_n])
/// \f] /// \f]
/// ///
class Convolution : public RequiresTensorViewArgs class Convolution : public RequiresTensorViewArgs
...@@ -55,6 +60,21 @@ namespace ngraph ...@@ -55,6 +60,21 @@ namespace ngraph
/// \param filters The node producing the filters tensor. /// \param filters The node producing the filters tensor.
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
/// \param window_dilation_strides The window dilation strides. /// \param window_dilation_strides The window dilation strides.
/// \param padding_below The padding-below sizes.
/// \param padding_above The padding-above sizes.
Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const Shape& padding_below,
const Shape& padding_above);
/// \brief Constructs a batched convolution operation with no padding (i.e., padding above and below are 0 everywhere).
///
/// \param image_batch The node producing the input image batch tensor.
/// \param filters The node producing the filters tensor.
/// \param window_movement_strides The window movement strides.
/// \param window_dilation_strides The window dilation strides.
Convolution(const std::shared_ptr<Node>& image_batch, Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
...@@ -83,12 +103,18 @@ namespace ngraph ...@@ -83,12 +103,18 @@ namespace ngraph
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
/// \return The window dilation strides. /// \return The window dilation strides.
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; } const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
/// \return The padding-below sizes.
const Shape& get_padding_below() const { return m_padding_below; }
/// \return The padding-above sizes.
const Strides& get_padding_above() const { return m_padding_above; }
/// \return The number of input channels. /// \return The number of input channels.
size_t get_input_channel_count() const { return m_input_channel_count; } size_t get_input_channel_count() const { return m_input_channel_count; }
/// \return The number of output channels. /// \return The number of output channels.
size_t get_output_channel_count() const { return m_output_channel_count; } size_t get_output_channel_count() const { return m_output_channel_count; }
/// \return The input image shape. /// \return The input image shape, not including padding.
const Shape& get_input_image_shape() const { return m_input_image_shape; } const Shape& get_input_image_shape() const { return m_input_image_shape; }
/// \return The input image shape, including padding.
const Shape& get_padded_input_image_shape() const { return m_padded_input_image_shape; }
/// \return The output image shape. /// \return The output image shape.
const Shape& get_output_image_shape() const { return m_output_image_shape; } const Shape& get_output_image_shape() const { return m_output_image_shape; }
/// \return The physical window shape. /// \return The physical window shape.
...@@ -102,15 +128,23 @@ namespace ngraph ...@@ -102,15 +128,23 @@ namespace ngraph
protected: protected:
Strides m_window_movement_strides; Strides m_window_movement_strides;
Strides m_window_dilation_strides; Strides m_window_dilation_strides;
Shape m_padding_below;
Shape m_padding_above;
// TODO: Some of these values should probably be computed dynamically rather than stored here.
size_t m_input_channel_count; size_t m_input_channel_count;
size_t m_output_channel_count; size_t m_output_channel_count;
Shape m_input_image_shape; Shape m_input_image_shape;
Shape m_padded_input_image_shape;
Shape m_output_image_shape; Shape m_output_image_shape;
Shape m_window_physical_shape; Shape m_window_physical_shape;
Shape m_window_virtual_shape; Shape m_window_virtual_shape;
size_t m_batch_size; size_t m_batch_size;
size_t m_image_dimension_count; size_t m_image_dimension_count;
private:
static Shape default_padding(const std::shared_ptr<Node>& image_batch);
static Strides default_strides(const std::shared_ptr<Node>& image_batch);
}; };
} }
} }
...@@ -1351,7 +1351,9 @@ void runtime::cpu::CPU_Emitter::EmitConvolution(const ngraph::Node* n, ...@@ -1351,7 +1351,9 @@ void runtime::cpu::CPU_Emitter::EmitConvolution(const ngraph::Node* n,
m_out << " {" << join(convolution->get_window_movement_strides()) m_out << " {" << join(convolution->get_window_movement_strides())
<< "},\n"; << "},\n";
m_out << " {" << join(convolution->get_window_dilation_strides()) m_out << " {" << join(convolution->get_window_dilation_strides())
<< "});\n"; << "},\n";
m_out << " {" << join(convolution->get_padding_below()) << "},\n";
m_out << " {" << join(convolution->get_padding_above()) << "});\n";
} }
void runtime::cpu::CPU_Emitter::EmitNot(const ngraph::Node* n, void runtime::cpu::CPU_Emitter::EmitNot(const ngraph::Node* n,
......
...@@ -287,7 +287,9 @@ private: ...@@ -287,7 +287,9 @@ private:
args[1]->get_shape(), args[1]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
c->get_window_movement_strides(), c->get_window_movement_strides(),
c->get_window_dilation_strides()); c->get_window_dilation_strides(),
c->get_padding_below(),
c->get_padding_above());
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ngraph/common.hpp" #include "ngraph/common.hpp"
#include "ngraph/coordinate_transform.hpp" #include "ngraph/coordinate_transform.hpp"
#include "ngraph/util.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -33,7 +34,9 @@ namespace ngraph ...@@ -33,7 +34,9 @@ namespace ngraph
const Shape& arg1_shape, const Shape& arg1_shape,
const Shape& out_shape, const Shape& out_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides) const Strides& window_dilation_strides,
const Shape& padding_below,
const Shape& padding_above)
{ {
// At the outermost level we will walk over every output coordinate O. // At the outermost level we will walk over every output coordinate O.
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
...@@ -60,6 +63,9 @@ namespace ngraph ...@@ -60,6 +63,9 @@ namespace ngraph
// with strides: // with strides:
// //
// (1,l_1,...,l_n). // (1,l_1,...,l_n).
//
// Note that we are iterating within the *padded* image batch, so further down we must check
// the current coordinate is in the padding.
size_t n_image_dimensions = arg0_shape.size() - 2; size_t n_image_dimensions = arg0_shape.size() - 2;
size_t n_input_channels = arg0_shape[1]; size_t n_input_channels = arg0_shape[1];
...@@ -67,6 +73,8 @@ namespace ngraph ...@@ -67,6 +73,8 @@ namespace ngraph
Shape input_batch_transform_start(2 + n_image_dimensions); Shape input_batch_transform_start(2 + n_image_dimensions);
Shape input_batch_transform_end(2 + n_image_dimensions); Shape input_batch_transform_end(2 + n_image_dimensions);
Shape input_batch_transform_strides(2 + n_image_dimensions, 1); Shape input_batch_transform_strides(2 + n_image_dimensions, 1);
Shape input_batch_padding_below(2 + n_image_dimensions, 0);
Shape input_batch_padding_above(2 + n_image_dimensions, 0);
input_batch_transform_start[0] = img_index; input_batch_transform_start[0] = img_index;
input_batch_transform_end[0] = img_index + 1; input_batch_transform_end[0] = img_index + 1;
...@@ -77,17 +85,30 @@ namespace ngraph ...@@ -77,17 +85,30 @@ namespace ngraph
{ {
size_t dilation_stride = window_dilation_strides[i - 2]; size_t dilation_stride = window_dilation_strides[i - 2];
size_t movement_stride = window_movement_strides[i - 2]; size_t movement_stride = window_movement_strides[i - 2];
size_t below_pad = padding_below[i - 2];
size_t above_pad = padding_above[i - 2];
input_batch_transform_start[i] = movement_stride * out_coord[i]; input_batch_transform_start[i] = movement_stride * out_coord[i];
input_batch_transform_end[i] = input_batch_transform_start[i] + input_batch_transform_end[i] = input_batch_transform_start[i] +
(arg1_shape[i] - 1) * dilation_stride + 1; (arg1_shape[i] - 1) * dilation_stride + 1;
input_batch_transform_strides[i] = dilation_stride; input_batch_transform_strides[i] = dilation_stride;
input_batch_padding_below[i] = below_pad;
input_batch_padding_above[i] = above_pad;
} }
AxisVector input_batch_axis_order(2 + n_image_dimensions);
size_t n = 0;
std::generate(input_batch_axis_order.begin(),
input_batch_axis_order.end(),
[&n]() -> size_t { return n++; });
CoordinateTransform input_batch_transform(arg0_shape, CoordinateTransform input_batch_transform(arg0_shape,
input_batch_transform_start, input_batch_transform_start,
input_batch_transform_end, input_batch_transform_end,
input_batch_transform_strides); input_batch_transform_strides,
input_batch_axis_order,
input_batch_padding_below,
input_batch_padding_above);
// Simultaneously with iterating I, for the filters we need to iterate the coordinate: // Simultaneously with iterating I, for the filters we need to iterate the coordinate:
// //
...@@ -130,8 +151,10 @@ namespace ngraph ...@@ -130,8 +151,10 @@ namespace ngraph
{ {
Coordinate input_batch_coord = *input_it++; Coordinate input_batch_coord = *input_it++;
Coordinate filter_coord = *filter_it++; Coordinate filter_coord = *filter_it++;
result += arg0[input_batch_transform.index(input_batch_coord)] * T v = input_batch_transform.in_padding(input_batch_coord)
arg1[filter_transform.index(filter_coord)]; ? 0
: arg0[input_batch_transform.index(input_batch_coord)];
result += v * arg1[filter_transform.index(filter_coord)];
} }
out[output_transform.index(out_coord)] = result; out[output_transform.index(out_coord)] = result;
......
...@@ -370,8 +370,14 @@ static shared_ptr<ngraph::Function> ...@@ -370,8 +370,14 @@ static shared_ptr<ngraph::Function>
node_js.at("window_movement_strides").get<vector<size_t>>(); node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides = auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>(); node_js.at("window_dilation_strides").get<vector<size_t>>();
node = make_shared<op::Convolution>( auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
args[0], args[1], window_movement_strides, window_dilation_strides); auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
node = make_shared<op::Convolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above);
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
...@@ -638,6 +644,8 @@ static json write(const Node& n) ...@@ -638,6 +644,8 @@ static json write(const Node& n)
auto tmp = dynamic_cast<const op::Convolution*>(&n); auto tmp = dynamic_cast<const op::Convolution*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides(); node["window_movement_strides"] = tmp->get_window_movement_strides();
node["window_dilation_strides"] = tmp->get_window_dilation_strides(); node["window_dilation_strides"] = tmp->get_window_dilation_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
......
This diff is collapsed.
...@@ -68,20 +68,27 @@ def tuple_times(t1,t2): ...@@ -68,20 +68,27 @@ def tuple_times(t1,t2):
# filter : [Co][Ci][W1]...[Wn] # filter : [Co][Ci][W1]...[Wn]
# move_strides = (s1,...,sn) # move_strides = (s1,...,sn)
# dilation_strides = (l1,...,ln) # dilation_strides = (l1,...,ln)
# below_pads = (p1,...,pn)
# above_pads = (q1,...,qn)
# #
# Returns: # Returns:
# output_batch : [N ][Co][D'1]...[D'n] # output_batch : [N ][Co][D'1]...[D'n]
# #
# Where the D's are computed according to TensorFlow-style "valid" convolution rules. # Where the D's are computed according to TensorFlow-style "valid" convolution rules, but *after* padding.
# See https://www.tensorflow.org/api_docs/python/tf/nn/convolution. # See https://www.tensorflow.org/api_docs/python/tf/nn/convolution.
# #
def convolution_ref(img_batch, filter, move_strides, dilation_strides): def convolution_ref(img_batch, filter, move_strides, dilation_strides, below_pads, above_pads):
assert(len(img_batch.shape) == len(filter.shape)) assert(len(img_batch.shape) == len(filter.shape))
assert(len(img_batch.shape) > 2) assert(len(img_batch.shape) > 2)
assert(img_batch.shape[1] == filter.shape[1]) assert(img_batch.shape[1] == filter.shape[1])
assert(len(move_strides) == len(img_batch.shape) - 2) assert(len(move_strides) == len(img_batch.shape) - 2)
assert(len(dilation_strides) == len(img_batch.shape) - 2) assert(len(dilation_strides) == len(img_batch.shape) - 2)
# Pad the input batch.
below_pads = (0,0) + below_pads # Have to add values for the image and channel dims.
above_pads = (0,0) + above_pads # Have to add values for the image and channel dims.
img_batch = np.pad(img_batch, zip(below_pads,above_pads), mode='constant', constant_values=0)
img_count = img_batch.shape[0] # N img_count = img_batch.shape[0] # N
ci_count = img_batch.shape[1] # Ci ci_count = img_batch.shape[1] # Ci
co_count = filter.shape[0] # Co co_count = filter.shape[0] # Co
...@@ -153,23 +160,22 @@ def data_str(data): ...@@ -153,23 +160,22 @@ def data_str(data):
return result return result
def emit_test(t,f): def emit_test(t,f):
test_name, input_batch_data, filter_data, move_strides, dilation_strides = t test_name, input_batch_data, filter_data, move_strides, dilation_strides, below_pads, above_pads = t
print ("Generating convolution test '%s'..." % test_name) print ("Generating convolution test '%s'..." % test_name)
output_batch_data = convolution_ref(input_batch_data,filter_data,move_strides,dilation_strides) output_batch_data = convolution_ref(input_batch_data,filter_data,move_strides,dilation_strides,below_pads,above_pads)
template = ''' template = '''
TEST (${BACKEND_NAME}, %s) TEST (${BACKEND_NAME}, %s)
{ {
auto shape_a = Shape{%s}; auto shape_a = Shape{%s};
auto A = make_shared<op::Parameter>(element::Float64::element_type(), shape_a); auto A = make_shared<op::Parameter>(element::f64, shape_a);
auto shape_b = Shape{%s}; auto shape_b = Shape{%s};
auto B = make_shared<op::Parameter>(element::Float64::element_type(), shape_b); auto B = make_shared<op::Parameter>(element::f64, shape_b);
auto shape_r = Shape{%s}; auto shape_r = Shape{%s};
auto result_type = make_shared<TensorViewType>(element::Float64::element_type(), shape_r);
auto f = make_shared<Function>( auto f = make_shared<Function>(
make_shared<op::Convolution>(A, B, Strides{%s}, Strides{%s}), result_type, op::Parameters{A, B}); make_shared<op::Convolution>(A, B, Strides{%s}, Strides{%s}, Shape{%s}, Shape{%s}), op::Parameters{A, B});
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f); auto external = manager->compile(f);
...@@ -177,11 +183,11 @@ TEST (${BACKEND_NAME}, %s) ...@@ -177,11 +183,11 @@ TEST (${BACKEND_NAME}, %s)
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
// Create some tensors for input/output // Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float64::element_type(), shape_a); auto a = backend->make_primary_tensor_view(element::f64, shape_a);
copy_data(a, vector<double>{%s}); copy_data(a, vector<double>{%s});
auto b = backend->make_primary_tensor_view(element::Float64::element_type(), shape_b); auto b = backend->make_primary_tensor_view(element::f64, shape_b);
copy_data(b, vector<double>{%s}); copy_data(b, vector<double>{%s});
auto result = backend->make_primary_tensor_view(element::Float64::element_type(), shape_r); auto result = backend->make_primary_tensor_view(element::f64, shape_r);
vector<double> expected_result{%s}; vector<double> expected_result{%s};
...@@ -196,22 +202,34 @@ TEST (${BACKEND_NAME}, %s) ...@@ -196,22 +202,34 @@ TEST (${BACKEND_NAME}, %s)
shape_str(output_batch_data.shape), shape_str(output_batch_data.shape),
shape_str(move_strides), shape_str(move_strides),
shape_str(dilation_strides), shape_str(dilation_strides),
shape_str(below_pads),
shape_str(above_pads),
data_str(input_batch_data), data_str(input_batch_data),
data_str(filter_data), data_str(filter_data),
data_str(output_batch_data))); data_str(output_batch_data)));
# test name input image batch filters stride dilation # test name input image batch filters stride dilation below-pads above-pads
tests = [ tests = [
("convolution_2d_1image", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1)), ("convolution_2d_1image", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (0,0), (0,0)),
("convolution_2d_2images", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1)), ("convolution_2d_1image_padded_1_1x1_1", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (1,1), (1,1)),
("convolution_2d_2images_strided", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1)), ("convolution_2d_1image_padded_2_3x4_5", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (2,3), (4,5)),
("convolution_2d_2images_dilated", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (2,2)), ("convolution_2d_2images", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (0,0), (0,0)),
("convolution_3d_2images", shaped_linspace((2,1,3,5,8)), shaped_linspace((2,1,2,2,3)), (1,1,1), (1,1,1)), ("convolution_2d_2images_strided", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1), (0,0), (0,0)),
("convolution_4d_2images", shaped_linspace((2,1,3,5,8,7)),shaped_linspace((2,1,2,2,3,1)),(1,1,1,1),(1,1,1,1)), ("convolution_2d_2images_strided_padded", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1), (4,2), (5,7)),
("convolution_4d_4images", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(1,1,1,1),(1,1,1,1)), ("convolution_2d_2images_strided_padded_same",
("convolution_4d_4images_strided", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(2,1,3,2),(1,1,1,1)), shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1), (2,2), (2,2)),
("convolution_4d_4images_dilated", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(1,1,1,1),(2,1,3,2)), ("convolution_2d_2images_dilated", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (2,2), (0,0), (0,0)),
("convolution_4d_4images_strided_dilated",shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2)), ("convolution_2d_2images_dilated_padded", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (2,2), (4,2), (5,7)),
("convolution_3d_2images", shaped_linspace((2,1,3,5,8)), shaped_linspace((2,1,2,2,3)), (1,1,1), (1,1,1), (0,0,0), (0,0,0)),
("convolution_4d_2images", shaped_linspace((2,1,3,5,8,7)),shaped_linspace((2,1,2,2,3,1)),(1,1,1,1),(1,1,1,1),(0,0,0,0), (0,0,0,0)),
("convolution_4d_4images", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(1,1,1,1),(1,1,1,1),(0,0,0,0), (0,0,0,0)),
("convolution_4d_4images_strided", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(2,1,3,2),(1,1,1,1),(0,0,0,0), (0,0,0,0)),
("convolution_4d_4images_dilated", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(1,1,1,1),(2,1,3,2),(0,0,0,0), (0,0,0,0)),
("convolution_4d_4images_strided_dilated",shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2),(0,0,0,0), (0,0,0,0)),
("convolution_4d_4images_strided_dilated_padded",
shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2),(2,4,6,8), (1,3,5,7)),
("convolution_4d_4images_strided_dilated_padded_same",
shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2),(3,3,3,3), (3,3,3,3)),
] ]
def main(): def main():
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment