Commit fa221c5f authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Refactor CPUWorkspaceInsertion to simplify its use in MxNet (#988)

* refactor cpworkspaceinsertion for mxnet

* rename mxnet functions to adhere to ngraph naming convention

* define a member static const int in a cpp file to resolve a linking issue
parent a1d78033
......@@ -42,10 +42,10 @@ public:
void initialize_default_passes();
template <typename T, class... Args>
void register_pass(Args... args)
void register_pass(Args&&... args)
{
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
auto pass = std::make_shared<T>(args...);
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base);
if (m_visualize || m_serialize)
......
......@@ -295,6 +295,9 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Or), &runtime::cpu::CPU_Emitter::emit<op::Or>},
};
const size_t runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction::s_memory_pool_alignment =
4096;
runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction(
const shared_ptr<ngraph::Function>& function, bool release_function)
: m_function(function)
......@@ -322,6 +325,9 @@ void runtime::cpu::CPU_ExternalFunction::compile()
ngraph::pass::Manager pass_manager;
//nv_cwi is required only by some frontends
//in which case they should run this pass(CPUWorkspaceInsertion) explicitly
NodeVector nv_cwi;
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
......@@ -330,7 +336,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
......
......@@ -92,7 +92,7 @@ namespace ngraph
const std::string& get_function_name() const { return m_function_name; }
const std::shared_ptr<ngraph::Function> get_function() { return m_function; }
// Temporary Memory Pool alignment
static const size_t s_memory_pool_alignment = 4096;
static const size_t s_memory_pool_alignment;
protected:
void compile();
......
......@@ -52,7 +52,9 @@
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
void ngraph::runtime::cpu::pass::CPUWorkspaceInsertion::construct_max_pool_with_indices()
using namespace ngraph;
static std::shared_ptr<pattern::Matcher> create_maxpool_with_indices_matcher()
{
Shape shape_data{1, 1, 14};
auto data = std::make_shared<pattern::op::Label>(element::f32, shape_data);
......@@ -66,8 +68,34 @@ void ngraph::runtime::cpu::pass::CPUWorkspaceInsertion::construct_max_pool_with_
max_pool->get_window_movement_strides(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
return std::make_shared<pattern::Matcher>(max_pool_bprop);
}
bool runtime::cpu::pass::CPUWorkspaceInsertion::run_on_function(std::shared_ptr<ngraph::Function> f)
{
auto matcher = create_maxpool_with_indices_matcher();
bool replaced = false;
for (auto n : f->get_ordered_ops())
{
if (n->is_output() || n->is_parameter())
{
continue;
}
pattern::graph_rewrite_callback callback = [data, delta](pattern::Matcher& m) {
if (matcher->match(n) && transform(*matcher))
{
replaced = true;
}
}
return replaced;
}
bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m)
{
auto data = std::dynamic_pointer_cast<pattern::op::Label>(m.get_pattern()->get_argument(0));
auto delta = std::dynamic_pointer_cast<pattern::op::Label>(m.get_pattern()->get_argument(1));
NGRAPH_DEBUG << "In a callback for construct_max_pool_with_indices against "
<< m.get_match_root()->get_name();
......@@ -130,8 +158,8 @@ void ngraph::runtime::cpu::pass::CPUWorkspaceInsertion::construct_max_pool_with_
}
//create a new max_pool_with_indices_bprop
auto max_pool_with_indices_bprop = std::make_shared<op::MaxPoolWithIndicesBackprop>(
pattern_map[data],
auto max_pool_with_indices_bprop =
std::make_shared<op::MaxPoolWithIndicesBackprop>(pattern_map[data],
pattern_map[delta],
max_pool_with_indices_indices,
m_max_pool->get_window_shape(),
......@@ -140,9 +168,6 @@ void ngraph::runtime::cpu::pass::CPUWorkspaceInsertion::construct_max_pool_with_
m_max_pool->get_padding_above());
ngraph::replace_node(m_max_pool_bprop, max_pool_with_indices_bprop);
m_indices_list.push_back(max_pool_with_indices_indices);
return true;
};
auto m = std::make_shared<pattern::Matcher>(max_pool_bprop, callback);
this->add_matcher(m);
}
......@@ -14,9 +14,27 @@
* limitations under the License.
*******************************************************************************/
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pattern/matcher.hpp"
namespace ngraph
{
......@@ -32,15 +50,18 @@ namespace ngraph
}
}
class ngraph::runtime::cpu::pass::CPUWorkspaceInsertion : public ngraph::pass::GraphRewrite
class ngraph::runtime::cpu::pass::CPUWorkspaceInsertion : public ngraph::pass::FunctionPass
{
public:
CPUWorkspaceInsertion()
: GraphRewrite()
CPUWorkspaceInsertion(ngraph::NodeVector& indices_list)
: FunctionPass()
, m_indices_list(indices_list)
{
construct_max_pool_with_indices();
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
private:
void construct_max_pool_with_indices();
ngraph::NodeVector& m_indices_list;
bool transform(ngraph::pattern::Matcher& m);
};
......@@ -29,6 +29,7 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sum.hpp"
......@@ -1194,9 +1195,10 @@ TEST(cpu_fusion, max_pool_with_indices)
}
{
NodeVector nv_cwi;
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_before.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_after.pdf");
pass_manager.run_passes(df);
}
......@@ -1256,9 +1258,10 @@ TEST(cpu_fusion, backwards_maxpool_with_indices_n4_c1_hw4_2x2_max)
auto df = autodiff::backprop_function(f);
{
NodeVector nv_cwi;
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_before2.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_after2.pdf");
pass_manager.run_passes(df);
}
......@@ -1267,6 +1270,118 @@ TEST(cpu_fusion, backwards_maxpool_with_indices_n4_c1_hw4_2x2_max)
ASSERT_TRUE(read_vector<float>(output) == expected);
}
static std::shared_ptr<ngraph::Function> make_forward_function()
{
Shape shape_a{10, 3, 28, 28};
auto input = std::make_shared<op::Parameter>(element::f32, shape_a);
Shape window_shape{2, 2};
auto max_pool = std::make_shared<op::MaxPool>(input, window_shape);
auto neg = std::make_shared<op::Negative>(max_pool);
auto absn = std::make_shared<op::Abs>(max_pool);
return std::make_shared<Function>(NodeVector{max_pool, neg, absn}, op::ParameterVector{input});
}
static std::pair<std::shared_ptr<ngraph::Function>, std::vector<std::shared_ptr<ngraph::Node>>>
make_backward_function(std::shared_ptr<ngraph::Function> f)
{
// get parameters
std::vector<std::shared_ptr<ngraph::op::Parameter>> back_parameters = f->get_parameters();
ngraph::NodeVector adjoints;
ngraph::NodeVector outputs;
for (auto Y : f->get_results())
{
// Get the output
// Create the Adjoint
auto C = std::make_shared<ngraph::op::Parameter>(Y->get_element_type(), Y->get_shape());
outputs.push_back(Y);
adjoints.push_back(C);
}
ngraph::autodiff::Adjoints adjoint{outputs, adjoints};
// Perform autodiff
std::vector<std::shared_ptr<Node>> dYdXs(back_parameters.size());
transform(back_parameters.begin(),
back_parameters.end(),
dYdXs.begin(),
[&adjoint](const std::shared_ptr<Node>& X) { return adjoint.backprop_node(X); });
// create the backward function
std::vector<std::shared_ptr<ngraph::op::Parameter>> param_adjoints;
for (auto n : adjoints)
param_adjoints.push_back(std::dynamic_pointer_cast<ngraph::op::Parameter>(n));
back_parameters.insert(back_parameters.begin(), param_adjoints.begin(), param_adjoints.end());
return {std::make_shared<ngraph::Function>(dYdXs, back_parameters), adjoints};
}
void optimize_graph(std::shared_ptr<ngraph::Function>& f, std::shared_ptr<ngraph::Function> bf)
{
// start by removing excess reshapes
NodeVector nv_cwi;
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
pass_manager.register_pass<pass::VisualizeTree>("before.fprop_cache.pdf");
pass_manager.run_passes(f);
pass_manager.run_passes(bf);
if (nv_cwi.size() > 0)
{
NodeVector new_outputs;
for (auto r : f->get_results())
{
new_outputs.push_back(r->get_argument(0));
}
new_outputs.insert(new_outputs.end(), nv_cwi.begin(), nv_cwi.end());
f = std::make_shared<ngraph::Function>(new_outputs, f->get_parameters());
}
ngraph::NodeVector dYdXs;
for (size_t i = 0; i < bf->get_output_size(); ++i)
{
dYdXs.push_back(bf->get_output_op(i)->get_argument(0));
}
ngraph::NodeVector combined_outputs;
for (auto r : f->get_results())
{
combined_outputs.push_back(r->get_argument(0));
}
combined_outputs.insert(combined_outputs.end(), dYdXs.begin(), dYdXs.end());
std::vector<std::shared_ptr<ngraph::op::Parameter>> combined_parameters = f->get_parameters();
std::vector<std::shared_ptr<ngraph::op::Parameter>> back_parameters = bf->get_parameters();
combined_parameters.insert(
combined_parameters.end(), back_parameters.begin(), back_parameters.end());
auto combinedf = std::make_shared<ngraph::Function>(combined_outputs, combined_parameters);
// rerun Reshape elimination to help simplify the graph again, run CPUFusion
// this replaces nodes in both f and bf due to shared-ptr - ness
ngraph::pass::Manager pass_manager_comb;
pass_manager_comb.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager_comb.register_pass<ngraph::runtime::cpu::pass::CPUFusion>();
pass_manager_comb.run_passes(combinedf);
}
TEST(cpu_fusion, maxpool_with_indices_in_mxnet)
{
auto f = make_forward_function();
auto bfa = make_backward_function(f);
auto maybe_bf = bfa.first;
auto adjoints = bfa.second;
optimize_graph(f, maybe_bf);
auto fprop_cache = ngraph::cache_fprop(f, maybe_bf, adjoints);
auto mpwi_bprop = fprop_cache.bprop->get_results().at(0)->get_argument(0);
ASSERT_TRUE(std::dynamic_pointer_cast<op::Parameter>(mpwi_bprop->get_argument(0)));
ASSERT_TRUE(std::dynamic_pointer_cast<op::Parameter>(mpwi_bprop->get_argument(2)));
}
TEST(cpu_fusion, batch_norm_folding)
{
Shape shape_input{1, 8, 3, 3};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment