Commit 61be3814 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Refactor convolution and pooling type prop (#1817)

* WIP

* More WIP

* More chiseling

* Move conv validation utils to a separate file; update unit tests

* Fix invalid attributes in pattern containing ConvolutionBackpropFilters

* Remove zero_const_conv test (it's no longer possible to construct the graph being tested)

* Rename infer_convolution_output_item_shape to infer_windowed_reduction_output_shape and add a boolean flag to control whether window-all-in-padding is allowed

* Add generalized function for inferring pooling fprop, use it in AvgPool/AvgPoolBackprop

* Update MaxPool to use new utility functions

* Fix comment

* Remove faulty and redundant check for window shape relative to pre-padding data shape

* Revert change to pattern construction in cpu_fusion

* Update unit test for maxpool

* Restore unjustly eliminated tests; move some computation to ptrdiff_t for safety; fix wording on some error messages

* Formatting
parent 05aa1be8
......@@ -155,6 +155,7 @@ set (SRC
strides.cpp
type/element_type.cpp
util.cpp
validation_util.cpp
graph_util.cpp
placement.cpp
cpio.cpp
......
This diff is collapsed.
This diff is collapsed.
......@@ -356,6 +356,9 @@ namespace ngraph
namespace util
{
// This is a legacy function, retained because the CPU backend uses it for now.
// TODO: Update CPU backend to use the new stuff in validation_util.hpp, and remove
// this function.
Shape infer_convolution_output_shape(const Node* node,
const Shape& data_batch_shape,
const Shape& filters_shape,
......
This diff is collapsed.
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <tuple>
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
Shape infer_windowed_reduction_output_shape(const Node* node,
const Shape& data_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const Shape& window_shape,
const Strides& window_strides,
const Strides& window_dilation,
bool is_window_all_in_padding_allowed);
std::tuple<element::Type, Shape>
infer_convolution_forward(const Node* node,
element::Type et_batch,
element::Type et_filters,
const Shape& data_batch_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const Shape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation);
Shape infer_batched_pooling_forward(const Node* node,
const Shape& data_batch_shape,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const Shape& window_shape,
const Strides& window_strides,
bool is_window_all_in_padding_allowed);
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment