Commit 6b3f3a0a authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

extend cse to handle backend ops (#1972)

* extend cse to handle backend ops

* revert back to static casts
parent ee6444ed
......@@ -175,8 +175,12 @@ static std::unordered_map<std::type_index,
class NodeKey
{
public:
NodeKey(std::shared_ptr<Node> n)
NodeKey(std::shared_ptr<Node> n,
std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
backend_handlers)
: m_node(n)
, m_backend_handlers(backend_handlers)
{
}
......@@ -191,17 +195,30 @@ public:
return false;
}
{
auto eh = ops_to_cse_handlers.find(TI(p_this));
if (eh == ops_to_cse_handlers.end())
if (eh != ops_to_cse_handlers.end())
{
return false;
return eh->second(m_node, other.get_node());
}
}
{
auto eh = m_backend_handlers.find(TI(p_this));
if (eh != m_backend_handlers.end())
{
return eh->second(m_node, other.get_node());
}
}
return false;
}
private:
std::shared_ptr<Node> m_node;
std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
m_backend_handlers;
};
namespace std
......@@ -254,7 +271,7 @@ bool ngraph::pass::CommonSubexpressionElimination::run_on_function(
continue;
}
NodeKey n_key{n};
NodeKey n_key(n, m_backend_cse_handlers);
if (expressions.count(n_key))
{
ngraph::replace_node(n, expressions.at(n_key));
......
......@@ -34,5 +34,18 @@ public:
{
}
CommonSubexpressionElimination(
const std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
backend_cse_handlers)
: FunctionPass()
, m_backend_cse_handlers(backend_cse_handlers)
{
}
std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
m_backend_cse_handlers;
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
......@@ -28,6 +28,7 @@ set(SRC
cpu_tensor_view.cpp
cpu_tracing.cpp
cpu_visualize_tree.cpp
cpu_cse.cpp
cpu_debugger.cpp
builder/add.cpp
builder/allreduce.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "cpu_cse.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
using namespace mkldnn;
using namespace ngraph;
using namespace std;
#define TI(x) std::type_index(typeid(x))
static bool cse_convertlayout(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
return false;
}
namespace ngraph
{
namespace runtime
{
namespace cpu
{
const std::unordered_map<
std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
get_cse_handlers_map()
{
const static std::unordered_map<
std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
cse_map{{TI(runtime::cpu::op::ConvertLayout), cse_convertlayout}};
return cse_map;
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
const std::unordered_map<
std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
get_cse_handlers_map();
}
}
}
......@@ -132,6 +132,7 @@
#include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_cse.hpp"
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
......@@ -1036,7 +1037,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
// pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
// pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
......@@ -1049,6 +1050,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>(
runtime::cpu::get_cse_handlers_map());
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<runtime::cpu::pass::CPUMemoryOptimization>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment