Commit 67c0488b authored by Chris Sullivan, committed by Scott Cyphers

Add MaxPool as an explicit argument to MaxPoolBackprop (#2065)

* remove forward op

* fix breaks

* fix pybind c-tor for max_pool_bprop

* Add new c-tor to MaxPoolBackprop that takes MaxPool as an explicit argument. Add serializer support (a usage sketch follows the commit metadata below).

* Add nvgpu support for new backward pooling c-tor, and calculate fprop when it isn't available.

* Add extra layout for 3 arg maxpool backprop.

* Formatting.

* cpu_workspace_insertion to expect 3-arg maxpool bprop

* GPU: add bprop_needs_pooling flag to primitive hash

* Update INTELGPU arguments_check for MaxPoolBackprop and GPU invocation for avg pool.
parent a708df68
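For orientation before the diff: a minimal sketch of how the new three-argument c-tor would be wired up from graph-building code. The shapes, parameter values, and the helper function name below are illustrative assumptions, not part of this commit.

#include <memory>

#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

// Hypothetical helper, for illustration only.
std::shared_ptr<Node> make_max_pool_bprop_example()
{
    // Illustrative NCHW input and 2x2 pooling parameters.
    Shape input_shape{1, 1, 8, 8};
    Shape window_shape{2, 2};
    Strides window_movement_strides{2, 2};
    Shape padding_below{0, 0};
    Shape padding_above{0, 0};

    auto data = std::make_shared<op::Parameter>(element::f32, input_shape);
    auto max_pool = std::make_shared<op::MaxPool>(
        data, window_shape, window_movement_strides, padding_below, padding_above);

    // The adjoint with respect to the forward output has the forward output's shape.
    auto delta = std::make_shared<op::Parameter>(element::f32, max_pool->get_shape());

    // New 3-arg c-tor: the MaxPool result is an explicit graph argument,
    // replacing the old weak_ptr<op::MaxPool> forward_op member.
    return std::make_shared<op::MaxPoolBackprop>(data,
                                                 delta,
                                                 max_pool,
                                                 window_shape,
                                                 window_movement_strides,
                                                 padding_below,
                                                 padding_above);
}

The 2-arg form is kept for backends and deserialized graphs that do not carry the forward result; per the commit message, nvgpu recomputes the fprop in that case.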
@@ -50,6 +50,5 @@ void regclass_pyngraph_op_MaxPoolBackprop(py::module m)
               const ngraph::Shape&,
               const ngraph::Strides&,
               const ngraph::Shape&,
-              const ngraph::Shape&,
-              const std::shared_ptr<ngraph::op::MaxPool>&>());
+              const ngraph::Shape&>());
 }
@@ -102,14 +102,28 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
                                      const Shape& window_shape,
                                      const Strides& window_movement_strides,
                                      const Shape& padding_below,
-                                     const Shape& padding_above,
-                                     const shared_ptr<op::MaxPool>& forward_op)
+                                     const Shape& padding_above)
     : Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta}))
     , m_window_shape(window_shape)
     , m_window_movement_strides(window_movement_strides)
     , m_padding_below(padding_below)
     , m_padding_above(padding_above)
-    , m_forward_op(forward_op)
+{
+    constructor_validate_and_infer_types();
+}
+
+op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
+                                     const shared_ptr<Node>& delta,
+                                     const shared_ptr<Node>& result_forward,
+                                     const Shape& window_shape,
+                                     const Strides& window_movement_strides,
+                                     const Shape& padding_below,
+                                     const Shape& padding_above)
+    : Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta, result_forward}))
+    , m_window_shape(window_shape)
+    , m_window_movement_strides(window_movement_strides)
+    , m_padding_below(padding_below)
+    , m_padding_above(padding_above)
 {
     constructor_validate_and_infer_types();
 }
@@ -152,14 +166,20 @@ void op::MaxPoolBackprop::validate_and_infer_types()
     set_output_type(0, get_input_element_type(0), forward_arg_shape);
 }
 
-shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const
-{
-    return m_forward_op.lock();
-}
-
 shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
+    if (this->get_arguments().size() == 3)
+    {
+        return make_shared<op::MaxPoolBackprop>(new_args.at(0),
+                                                new_args.at(1),
+                                                new_args.at(2),
+                                                m_window_shape,
+                                                m_window_movement_strides,
+                                                m_padding_below,
+                                                m_padding_above);
+    }
     return make_shared<op::MaxPoolBackprop>(new_args.at(0),
                                             new_args.at(1),
                                             m_window_shape,
@@ -176,11 +196,11 @@ void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
     auto backprop =
         make_shared<op::MaxPoolBackprop>(operand,
                                          delta,
+                                         static_pointer_cast<op::MaxPool>(shared_from_this()),
                                          m_window_shape,
                                          m_window_movement_strides,
                                          m_padding_below,
-                                         m_padding_above,
-                                         static_pointer_cast<op::MaxPool>(shared_from_this()));
+                                         m_padding_above);
 
     adjoints.add_delta(operand, backprop);
 }
@@ -92,8 +92,15 @@ namespace ngraph
                             const Shape& window_shape,
                             const Strides& window_movement_strides,
                             const Shape& padding_below,
-                            const Shape& padding_above,
-                            const std::shared_ptr<op::MaxPool>& forward_op = nullptr);
+                            const Shape& padding_above);
+
+            MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
+                            const std::shared_ptr<Node>& delta,
+                            const std::shared_ptr<Node>& result_forward,
+                            const Shape& window_shape,
+                            const Strides& window_movement_strides,
+                            const Shape& padding_below,
+                            const Shape& padding_above);
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
@@ -104,17 +111,11 @@ namespace ngraph
             const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
             const Shape& get_padding_below() const { return m_padding_below; }
             const Shape& get_padding_above() const { return m_padding_above; }
-            /// \return A pointer to the corresponding `MaxPool` forward prop op. This may be
-            ///         `nullptr` if no such pointer was provided at construction time, or if the
-            ///         forward op has been freed due to graph rewriting.
-            std::shared_ptr<op::MaxPool> get_forward_op() const;
-
         protected:
             Shape m_window_shape;
             Strides m_window_movement_strides;
             Shape m_padding_below;
             Shape m_padding_above;
-            std::weak_ptr<op::MaxPool> m_forward_op;
         };
     }
 }
@@ -1187,6 +1187,10 @@ namespace ngraph
                     {
                         i_mds.push_back(fwd_prim_desc.workspace_primitive_desc().desc());
                     }
+                    else if (node->get_input_size() == 3)
+                    {
+                        i_mds.push_back(diff_dst_desc);
+                    }
 
                     o_mds.push_back(prim_desc.diff_src_primitive_desc().desc());
                 }
...
@@ -61,9 +61,11 @@ static std::shared_ptr<pattern::Matcher> create_maxpool_with_indices_matcher()
     Shape window_shape{3};
     auto max_pool = std::make_shared<op::MaxPool>(data, window_shape);
     auto delta = std::make_shared<pattern::op::Label>(element::f32, max_pool->get_shape());
+    auto max_pool_label = std::make_shared<pattern::op::Label>(element::f32, max_pool->get_shape());
     auto max_pool_bprop =
         std::make_shared<op::MaxPoolBackprop>(data,
                                               delta,
+                                              max_pool_label,
                                               max_pool->get_window_shape(),
                                               max_pool->get_window_movement_strides(),
                                               max_pool->get_padding_below(),
@@ -96,10 +98,12 @@ bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m)
 {
     auto data = std::static_pointer_cast<pattern::op::Label>(m.get_pattern()->get_argument(0));
     auto delta = std::static_pointer_cast<pattern::op::Label>(m.get_pattern()->get_argument(1));
+    auto max_pool = std::static_pointer_cast<pattern::op::Label>(m.get_pattern()->get_argument(2));
     NGRAPH_DEBUG << "In a callback for construct_max_pool_with_indices against "
                  << m.get_match_root()->get_name();
 
     auto pattern_map = m.get_pattern_map();
+    auto m_max_pool = std::static_pointer_cast<op::MaxPool>(pattern_map[max_pool]);
     auto m_max_pool_bprop = std::static_pointer_cast<op::MaxPoolBackprop>(m.get_match_root());
 
     if (m_max_pool_bprop->get_shape().size() != 4 ||
@@ -110,31 +114,6 @@ bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m)
         return false;
     }
 
-    // find the original MaxPool now
-    std::shared_ptr<op::MaxPool> m_max_pool;
-    for (auto u : pattern_map[data]->get_users())
-    {
-        if (auto mp = std::dynamic_pointer_cast<op::MaxPool>(u))
-        {
-            if (mp->get_window_shape() == m_max_pool_bprop->get_window_shape() &&
-                mp->get_window_movement_strides() ==
-                    m_max_pool_bprop->get_window_movement_strides() &&
-                mp->get_padding_below() == m_max_pool_bprop->get_padding_below() &&
-                mp->get_padding_above() == m_max_pool_bprop->get_padding_above())
-            {
-                m_max_pool = mp;
-                break;
-            }
-        }
-    }
-
-    if (!m_max_pool)
-    {
-        NGRAPH_DEBUG << "MaxPool for " << pattern_map[data]->get_name() << " and "
-                     << m_max_pool_bprop->get_name() << " not found";
-        return false;
-    }
-
     auto max_pool_with_indices =
         std::make_shared<op::MaxPoolWithIndices>(pattern_map[data],
                                                  m_max_pool->get_window_shape(),
...
@@ -878,7 +878,7 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::MaxPool* node)
     /// end asymmetric padding detection
 
     size_t max_pool_index = build_pooling(CUDNN_POOLING_MAX,
-                                          output_type,
+                                          out[0].get_element_type(),
                                           CUDNNEmitter::Prop::Forward,
                                           input_shape_padded,
                                           result_shape,
@@ -1521,22 +1521,28 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_filter(
 }
 
 size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_op,
-                                                 const std::string& dtype,
+                                                 const element::Type& dtype,
                                                  const Prop& direction,
                                                  const Shape& input_shape,
                                                  const Shape& output_shape,
                                                  const Strides& window_strides,
                                                  const Shape& window_shape,
                                                  const Shape& padding_below,
-                                                 const Shape& padding_above)
+                                                 const Shape& padding_above,
+                                                 bool bprop_needs_pooling)
 {
     // construct hash to determine if kernel needs to be emitted
     // or if it already exists in the primitive list
     std::stringstream ss;
-    ss << "pool_op" << pool_op << "dtype_" << dtype << "_dir" << static_cast<int>(direction) << "_i"
-       << join(input_shape, "_") << "_o" << join(output_shape, "_") << "_ws"
-       << join(window_shape, "_") << "_wst" << join(window_strides, "_") << "_pb"
-       << join(padding_below, "_") << "_pb" << join(padding_above, "_");
+    ss << "pool_op" << pool_op << "dtype_" << dtype.c_type_string() << "_dir"
+       << static_cast<int>(direction) << "_i" << join(input_shape, "_") << "_o"
+       << join(output_shape, "_") << "_ws" << join(window_shape, "_") << "_wst"
+       << join(window_strides, "_") << "_pb" << join(padding_below, "_") << "_pa"
+       << join(padding_above, "_");
+    if (bprop_needs_pooling)
+    {
+        ss << "_fprop_bprop";
+    }
     std::string hash = ss.str();
 
     // check if the requested kernel is already an inserted primitive
@@ -1546,7 +1552,7 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_
         return primitive_index;
     }
 
-    const cudnnDataType_t data_type = get_cudnn_datatype(dtype);
+    const cudnnDataType_t data_type = get_cudnn_datatype(dtype.c_type_string());
     const cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
     auto& desc = m_descriptors.build<cudnnPoolingDescriptor_t>();
     auto& input_desc = tensor_descriptor_from_shape(input_shape, data_type, tensor_format);
@@ -1620,17 +1626,32 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_
            {
                throw std::runtime_error("Pooling does not support int type by cuDNN.");
            }
-           pool.reset(new gpu::primitive{
-               [=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
-                   // cuDNN requires the output tensor of the maxpool fprop to be passed even though
-                   // it is not mathematically necessary. It appears, however, that it is not actually
-                   // used as the adjoints are passed in place and the correct result is achieved.
+
+           if (bprop_needs_pooling)
+           {
+               GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
+               auto workspace_size_in_bytes = shape_size(output_shape) * dtype.size();
+               size_t workspace_idx = allocator.reserve_workspace(workspace_size_in_bytes);
+
+               pool.reset(new gpu::primitive{[=, &desc, &input_desc, &output_desc](void** inputs,
+                                                                                   void** outputs) {
+                   void* pooling_output = runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
+                   CUDNN_SAFE_CALL(cudnnPoolingForward(*m_ctx->cudnn_handle,
+                                                       desc,
+                                                       alpha,
+                                                       input_desc,
+                                                       inputs[0],
+                                                       beta,
+                                                       output_desc,
+                                                       pooling_output));
+                   debug_sync();
+
                    CUDNN_SAFE_CALL(cudnnPoolingBackward(*m_ctx->cudnn_handle,
                                                         desc,
                                                         alpha,
                                                         // output (wrt maxpool) tensor
                                                         output_desc,
-                                                        inputs[1],
+                                                        pooling_output,
                                                         // adjoint of output
                                                         output_desc,
                                                         inputs[1],
@@ -1643,6 +1664,31 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_
                                                         outputs[0]));
                    debug_sync();
                }});
+           }
+           else
+           {
+               pool.reset(new gpu::primitive{
+                   [=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
+                       CUDNN_SAFE_CALL(cudnnPoolingBackward(*m_ctx->cudnn_handle,
+                                                            desc,
+                                                            alpha,
+                                                            // output (wrt maxpool) tensor
+                                                            output_desc,
+                                                            inputs[2],
+                                                            // adjoint of output
+                                                            output_desc,
+                                                            inputs[1],
+                                                            // input (wrt maxpool) tensor
+                                                            input_desc,
+                                                            inputs[0],
+                                                            beta,
+                                                            // adjoint of input
+                                                            input_desc,
+                                                            outputs[0]));
+                       debug_sync();
+                   }});
+           }
            break;
        }
    }
...
@@ -128,14 +128,15 @@ namespace ngraph
                                  const double beta);
 
             size_t build_pooling(const cudnnPoolingMode_t& pool_op,
-                                 const std::string& dtype,
+                                 const element::Type& dtype,
                                  const Prop& direction,
                                  const ngraph::Shape& input_shape,
                                  const ngraph::Shape& output_shape,
                                  const ngraph::Strides& window_strides,
                                  const ngraph::Shape& window_shape,
                                  const ngraph::Shape& padding_below,
-                                 const ngraph::Shape& padding_above);
+                                 const ngraph::Shape& padding_above,
+                                 bool bprop_needs_pooling = false);
 
             size_t build_batchnorm(const cudnnBatchNormMode_t& bn_op,
                                    const std::string& dtype,
...
@@ -270,7 +270,7 @@ void runtime::gpu::GPU_Emitter::emit_AvgPool(EMIT_ARGS)
                                  : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
 
             index = cudnn_emitter->build_pooling(cudnn_avg_type,
-                                                 out[0].get_type(),
+                                                 out[0].get_element_type(),
                                                  CUDNNEmitter::Prop::Forward,
                                                  input_shape,
                                                  result_shape,
@@ -308,7 +308,7 @@ void runtime::gpu::GPU_Emitter::emit_AvgPoolBackprop(EMIT_ARGS)
                                  : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
 
             auto index = cudnn_emitter->build_pooling(cudnn_avg_type,
-                                                      out[0].get_type(),
+                                                      out[0].get_element_type(),
                                                       CUDNNEmitter::Prop::Backward,
                                                       output_shape,
                                                       delta_shape,
@@ -321,7 +321,7 @@ void runtime::gpu::GPU_Emitter::emit_AvgPoolBackprop(EMIT_ARGS)
             // the forward pass but does not use them. It also behaves differently
             // for max pool vs avg pool. The repetition of args below is to address
             // this interface in a way that supports both max and avg pooling
-            writer << "void* input[] = {" << node_names(args, {0, 0}) << "};\n";
+            writer << "void* input[] = {" << node_names(args, {0, 0, 0}) << "};\n";
             writer << "void* output[] = {" << node_names(out) << "};\n";
             writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
...
@@ -799,17 +799,19 @@ void runtime::gpu::GPU_Emitter::emit_MaxPoolBackprop(EMIT_ARGS)
             auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
 
+            bool needs_fprop = (args.size() != 3);
             if (fp_input_shape.size() >= 4)
             {
                 auto index = cudnn_emitter->build_pooling(CUDNN_POOLING_MAX,
-                                                          out[0].get_type(),
+                                                          out[0].get_element_type(),
                                                           CUDNNEmitter::Prop::Backward,
                                                           fp_input_shape,
                                                           fp_output_shape,
                                                           mpb->get_window_movement_strides(),
                                                           mpb->get_window_shape(),
                                                           mpb->get_padding_below(),
-                                                          mpb->get_padding_above());
+                                                          mpb->get_padding_above(),
+                                                          needs_fprop);
 
                 writer << "void* input[] = {" << node_names(args) << "};\n";
                 writer << "void* output[] = {" << node_names(out) << "};\n";
...
@@ -695,7 +695,14 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
         }
         case OP_TYPEID::MaxPoolBackprop:
         {
-            arguments_check(op, 2, 1);
+            if (op->get_input_size() == 3)
+            {
+                arguments_check(op, 3, 1);
+            }
+            else
+            {
+                arguments_check(op, 2, 1);
+            }
 
             const shared_ptr<op::MaxPoolBackprop> max_pool_b =
                 static_pointer_cast<op::MaxPoolBackprop>(op);
...
@@ -831,12 +831,25 @@ static shared_ptr<ngraph::Function>
                 node_js.at("window_movement_strides").get<vector<size_t>>();
             auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
             auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
-            node = make_shared<op::MaxPoolBackprop>(args[0],
-                                                    args[1],
-                                                    window_shape,
-                                                    window_movement_strides,
-                                                    padding_below,
-                                                    padding_above);
+            if (args.size() == 3)
+            {
+                node = make_shared<op::MaxPoolBackprop>(args[0],
+                                                        args[1],
+                                                        args[2],
+                                                        window_shape,
+                                                        window_movement_strides,
+                                                        padding_below,
+                                                        padding_above);
+            }
+            else
+            {
+                node = make_shared<op::MaxPoolBackprop>(args[0],
+                                                        args[1],
+                                                        window_shape,
+                                                        window_movement_strides,
+                                                        padding_below,
+                                                        padding_above);
+            }
             break;
         }
         case OP_TYPEID::Maximum:
...