Commit 14a2aeae authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[Standalone] Decouple MKLDNN primitive build from code generation (#2701)

* [Standalone] Decouple MKLDNN primitive build from code generation

This patch introduces a new pass, MKLDNNPrimitiveBuildPass, which
iterates over all the ops assigned to MKLDNN and builds their
corresponding primitives. Primitive indexes are stored in MKLDNNEmitter
and can easily be retrieved with the get_primitive_index(node)
interface. This decouples the creation of primitives from codegen and
fixes the problem of MKLDNN primitives being created twice
(CommonFunctionCollection pass and codegen).

Current assertions only allow the creation of a single primitive per
node but it should be simple to remove this when needed. Using a pass
might not be the best approach here but I found it convenient for the
current needs and it should be straightforward to convert into a utility,
if needed.

These changes caused a conflict with recently introduced
'build_quantized_inner_product*'. These new build methods will be ported
in a follow up patch to new build approach.

* Remove unrelated comment

* Remove TensorView code

* Set m_node_primitive_map from MKLDNNPrimitiveBuildPass

* Move node->primitive map from mkldnn pass to external function

* Fix struct/class inconsistency en fw declaration
parent 6a0101a2
......@@ -113,6 +113,7 @@ set(SRC
pass/cpu_mat_fusion.cpp
pass/cpu_memory_assignment.cpp
pass/cpu_memory_optimization.cpp
pass/cpu_mkldnn_primitive_build.cpp
pass/cpu_post_layout_optimizations.cpp
pass/cpu_rnn_fusion.cpp
pass/cpu_workspace_insertion.cpp
......
This diff is collapsed.
......@@ -175,6 +175,7 @@
#include "ngraph/runtime/cpu/pass/cpu_mat_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_memory_assignment.hpp"
#include "ngraph/runtime/cpu/pass/cpu_memory_optimization.hpp"
#include "ngraph/runtime/cpu/pass/cpu_mkldnn_primitive_build.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
......@@ -464,6 +465,11 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
ngraph::pass::Manager pass_manager;
register_common_passes(pass_manager, pass_config);
// Build mkldnn primitives for codegen.
pass_manager.register_pass<runtime::cpu::pass::MKLDNNPrimitiveBuildPass>(
*m_mkldnn_emitter, m_node_primitive_idx_map);
unordered_map<Node*, Node*> node_function_map;
string common_function_string;
auto femitter = bind(&ngraph::runtime::cpu::CPU_ExternalFunction::emit_op_as_function,
......@@ -1156,6 +1162,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
pass_config.get_pass_attribute("ReuseMemory");
pass_manager.register_pass<runtime::cpu::pass::CPUMemoryAssignment>(
bufferID_to_tensorSets, tensor_to_bufferID, size_t(s_memory_pool_alignment), !reuse_memory);
pass_manager.get_state().set_visualize_tree_ops_map(runtime::cpu::get_visualize_tree_ops_map());
}
......
......@@ -114,6 +114,16 @@ namespace ngraph
return m_mkldnn_emitter;
}
/// Returns the index of the mkldnn primitive previously created for \p node.
size_t get_primitive_index(const Node* node) const
{
auto it = m_node_primitive_idx_map.find(node);
NGRAPH_ASSERT(it != m_node_primitive_idx_map.end())
<< "Primitive not found for node " << node->description();
return it->second;
}
size_t add_state(ngraph::State* state)
{
m_states.push_back(state);
......@@ -296,6 +306,9 @@ namespace ngraph
std::unordered_map<std::string, int> subgraph_param_sizes;
std::unordered_map<std::string, std::reference_wrapper<void*>> subgraph_param_ptrs;
#endif
/// Map each node with mkldnn implementation to its mkldnn primitive index.
std::unordered_map<const Node*, size_t> m_node_primitive_idx_map;
};
}
}
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
#include <functional>
#include <typeindex>
#include <unordered_map>
#define BUILD_PRIMITIVE_DECL(op_name) \
build_primitive<op_name>(ngraph::runtime::cpu::MKLDNNEmitter & mkldnn_emitter, \
ngraph::Node * node)
namespace mkldnn
{
class primitive;
}
namespace ngraph
{
class Node;
namespace runtime
{
namespace cpu
{
class MKLDNNEmitter;
namespace pass
{
using PrimitiveBuildFunction =
std::function<size_t(ngraph::runtime::cpu::MKLDNNEmitter&, ngraph::Node*)>;
using PrimitiveBuildOpMap =
std::unordered_map<std::type_index, PrimitiveBuildFunction>;
/// This pass traverses the call graph and creates MKLDNN primitives for those ops
/// that have been assigned to MKLDNN.
class MKLDNNPrimitiveBuildPass : public ngraph::pass::CallGraphPass
{
private:
ngraph::runtime::cpu::MKLDNNEmitter& m_mkldnn_emitter;
/// External map to store each node with mkldnn implementation and its mkldnn
/// associated primitive index.
std::unordered_map<const Node*, size_t>& m_node_primitive_idx_map;
public:
MKLDNNPrimitiveBuildPass(
ngraph::runtime::cpu::MKLDNNEmitter& mkldnn_emitter,
std::unordered_map<const Node*, size_t>& node_primitive_idx_map)
: m_mkldnn_emitter(mkldnn_emitter)
, m_node_primitive_idx_map(node_primitive_idx_map)
{
}
bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
template <typename OP>
static size_t
build_primitive(ngraph::runtime::cpu::MKLDNNEmitter& mkldnn_emitter,
ngraph::Node* node)
{
throw std::runtime_error("Unimplemented op '" + node->description() +
"' in MKLDNNPrimitiveBuildPass");
}
};
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment