Commit 3bd00e23 authored by Diego Caballero, committed by nmostafa

[MLIR] Enable CompiledKernel as driver for MLIR backend (#18)

This patch leverages CompiledKernel to delimit sub-graphs to be compiled
with MLIR. It introduces a pass that creates a CompiledKernel for the
whole function (for now) and changes MLIRCompiler to align with this new
approach.
parent e3c28fd2
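At a high level, the pieces in this patch fit together roughly as in the sketch below. It only uses names that appear in the diff (CompiledKernel, MLIRCompiler, compile_and_run) and is an illustration of the intended flow, not the literal backend plumbing.

```cpp
// Illustrative sketch of the new driver flow (simplified; variable names are from the diff).
// 1) MLIRSubgraphExtractionPass wraps the function body in a single CompiledKernel op.
// 2) The CPU builder collects the kernel's input/output tensor pointers.
// 3) MLIRCompiler is driven by the CompiledKernel instead of a raw node list.
auto ck = std::make_shared<ngraph::op::CompiledKernel>(ck_ops, ck_outputs, ck_args);

std::vector<void*> external_tensors; // input tensor pointers first, then output tensor pointers
ngraph::runtime::ngmlir::MLIRCompiler compiler(ck.get(), external_tensors);
compiler.compile_and_run();
```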
......@@ -383,7 +383,6 @@ if (NGRAPH_CPU_ENABLE)
endif()
if (NGRAPH_MLIR_ENABLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNGRAPH_MLIR_ENABLE")
set(NGRAPH_MLIR_SOURCE_DIR ${CMAKE_SOURCE_DIR}/src/contrib/mlir)
endif()
......
......@@ -19,6 +19,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -54,6 +55,16 @@ using namespace ngraph::runtime::ngmlir;
namespace ngraph
{
MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
const std::vector<void*>& external_tensors)
: m_compiled_kernel(compiled_kernel)
, m_external_tensors(external_tensors)
{
NGRAPH_ASSERT((m_compiled_kernel->get_arguments().size() +
m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size())
<< "Number of arguments and outputs doesn't match number of tensors";
}
void MLIRCompiler::init_mlir()
{
mlir::registerDialect<NGDialect>();
......@@ -77,18 +88,21 @@ namespace ngraph
m_module = make_unique<mlir::Module>(&m_context);
TypeList args_type_list, result_type_list;
build_tensors_list();
NGRAPH_ASSERT(m_ip_tensors.size() != 0) << "Cannot have empty inputs list";
NGRAPH_ASSERT(m_op_tensors.size() != 0) << "Cannot have empty outputs list";
for (auto tensor : m_ip_tensors)
// Retrieve input and output tensors.
const auto& kernel_inputs = m_compiled_kernel->get_arguments();
const auto& kernel_outputs = m_compiled_kernel->get_kernel_outputs();
NGRAPH_ASSERT(kernel_inputs.size() != 0) << "Cannot have empty inputs list";
NGRAPH_ASSERT(kernel_outputs.size() != 0) << "Cannot have empty outputs list";
for (auto input : kernel_inputs)
{
args_type_list.push_back(get_mlir_type(tensor));
args_type_list.push_back(get_mlir_type(input->get_output_tensor_ptr().get()));
}
for (auto tensor : m_op_tensors)
for (auto output : kernel_outputs)
{
result_type_list.push_back(get_mlir_type(tensor));
result_type_list.push_back(get_mlir_type(output->get_output_tensor_ptr().get()));
}
auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);
......@@ -98,11 +112,12 @@ namespace ngraph
// populate Tensor->Value maps
int i = 0;
for (auto tensor : m_ip_tensors)
for (auto input : kernel_inputs)
{
mlir::Value* arg = function->getArgument(i);
TensorInfo tensor_info{arg};
m_tensor_to_value_map.insert(TensorToInfo(tensor, tensor_info));
m_tensor_to_value_map.insert(
TensorToInfo(input->get_output_tensor_ptr().get(), tensor_info));
i++;
}
......@@ -116,59 +131,6 @@ namespace ngraph
}
}
void MLIRCompiler::build_tensors_list()
{
for (const auto node : m_sub_graph)
{
// get all nodes output tensors
// if an output has a use out of the subgraph, it is an output tensor, else a temp.
for (auto i = 0; i < node->get_output_size(); i++)
{
const std::set<descriptor::Input*>& inputs = node->get_output_inputs(i);
auto tensor = node->get_output_tensor_ptr(i);
for (auto ip : inputs)
{
bool out_of_subgraph =
(std::find(std::begin(m_sub_graph),
std::end(m_sub_graph),
ip->get_node().get()) == std::end(m_sub_graph));
if (out_of_subgraph)
{
// we found a use out of subgraph, consider this an output tensor
// those would be added as return value for the mlir func
if (std::find(std::begin(m_op_tensors),
std::end(m_op_tensors),
tensor.get()) == std::end(m_op_tensors))
{
m_op_tensors.push_back(tensor.get());
}
}
}
}
// get over all input tensors
for (const auto arg : node->get_arguments())
{
bool out_of_subgraph =
(std::find(std::begin(m_sub_graph), std::end(m_sub_graph), arg.get()) ==
std::end(m_sub_graph));
if (out_of_subgraph)
{
for (auto i = 0; i < arg->get_output_size(); i++)
{
auto tensor = arg->get_output_tensor_ptr(i);
if (std::find(std::begin(m_ip_tensors),
std::end(m_ip_tensors),
tensor.get()) == std::end(m_ip_tensors))
{
m_ip_tensors.push_back(tensor.get());
}
}
}
}
}
}
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{
SmallVector<int64_t, 4> shape;
......@@ -249,27 +211,24 @@ namespace ngraph
void MLIRCompiler::build_ng_dialect()
{
// TODO: subgraph_topological_sort expects a list of shared_ptr. CPU BE has raw pointers.
// Fix this.
//for (auto node : subgraph_topological_sort(m_sub_graph))
NGRAPH_ASSERT(m_sub_graph.size() == 1) << "Supporting code-gen for a single node for now";
const NodeVector& sub_graph = m_compiled_kernel->get_node_list();
NGRAPH_ASSERT(sub_graph.size() == 1) << "Supporting code-gen for a single node for now";
auto np = sub_graph[0];
auto it = op_dispatcher.find(TI(*np));
if (it == op_dispatcher.end())
{
auto np = m_sub_graph[0];
auto it = op_dispatcher.find(TI(*np));
if (it == op_dispatcher.end())
{
throw unsupported_op{
std::string{"The MLIR backend doesn't currently implement the '"} +
np->description() + "' operation"};
}
mlir::Value* mlir_value = it->second(*this, np);
// builders that have multiple result values will update the value map, and set their ret values to null
if (mlir_value)
{
update_tensor_value(np->get_output_tensor_ptr().get(), mlir_value);
}
throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} +
np->description() + "' operation"};
}
mlir::Value* mlir_value = it->second(*this, np.get());
// Builders that have multiple result values will update the value map and set their return values to null
if (mlir_value)
{
update_tensor_value(np->get_output_tensor_ptr().get(), mlir_value);
}
create_return();
}
......@@ -308,9 +267,9 @@ namespace ngraph
void MLIRCompiler::create_return()
{
std::vector<mlir::Value*> value_list;
for (auto tensor : m_op_tensors)
for (auto output : m_compiled_kernel->get_kernel_outputs())
{
value_list.push_back(get_tensor_value(tensor).m_value);
value_list.push_back(get_tensor_value(output->get_output_tensor_ptr().get()).m_value);
}
m_builder->create<NG_ReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
}
......
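The main structural change above is that the MLIR function signature and the Tensor-to-Value map are now derived directly from the CompiledKernel's arguments and kernel outputs rather than from a scan of a raw node list (the removed build_tensors_list). Together with the CPU builder further down, the constructor assert implies a layout for external_tensors; a minimal sketch of that contract, with hypothetical pointer names:

```cpp
// Assumed layout of the 'external_tensors' vector handed to MLIRCompiler:
// all kernel arguments first, then all kernel outputs. This ordering is what the
// CPU builder in this patch produces; the constructor only asserts the total count.
std::vector<void*> external_tensors = {
    arg0_ptr, arg1_ptr, // one pointer per compiled_kernel->get_arguments() entry (hypothetical)
    out0_ptr            // one pointer per compiled_kernel->get_kernel_outputs() entry (hypothetical)
};
NGRAPH_ASSERT(external_tensors.size() ==
              compiled_kernel->get_arguments().size() +
              compiled_kernel->get_kernel_outputs().size())
    << "Number of arguments and outputs doesn't match number of tensors";
```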
......@@ -38,6 +38,10 @@ namespace mlir
namespace ngraph
{
namespace op
{
class CompiledKernel;
}
namespace runtime
{
namespace ngmlir
......@@ -51,12 +55,8 @@ namespace ngraph
using TensorList = std::vector<descriptor::Tensor*>;
using TypeList = llvm::SmallVector<mlir::Type, 4>;
MLIRCompiler(const std::vector<const Node*>& sub_graph,
const std::vector<void*>& external_tensors)
: m_sub_graph(sub_graph.begin(), sub_graph.end())
, m_external_tensors(external_tensors)
{
}
MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
const std::vector<void*>& external_tensors);
/// Compiles and runs a subgraph in MLIR
void compile_and_run();
......@@ -84,8 +84,6 @@ namespace ngraph
void execute();
void cleanup();
/// Collects input and output tensors to this sub-graph
void build_tensors_list();
mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
mlir::Type get_mlir_type(const element::Type& type);
TensorInfo get_tensor_value(descriptor::Tensor* tensor);
......@@ -122,15 +120,16 @@ namespace ngraph
std::function<mlir::Value*(MLIRCompiler& compiler, const ngraph::Node*)>;
using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>;
llvm::SmallVector<const Node*, 4> m_sub_graph;
// Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiled_kernel;
// Pointers to externally allocated memory for sub-graph's input and output tensors.
const std::vector<void*>& m_external_tensors;
llvm::SmallVector<void*, 8> m_invoke_args;
// Maps a tensor to the value that represents it in the IR.
// Used for MLIR dialect generation.
TensorToInfoMap m_tensor_to_value_map;
// List of input and output tensors in the graph
TensorList m_ip_tensors, m_op_tensors;
static const MLIRCompOpMap op_dispatcher;
// Memory manager for temp allocations inside JIT'ed code
......
......@@ -493,6 +493,13 @@ set(SRC ${SRC}
runtime/dynamic/dynamic_backend.hpp
)
# MLIR specific files
set(SRC
${SRC}
pass/mlir_subgraph_extraction.cpp
pass/mlir_subgraph_extraction.hpp
)
if(NGRAPH_JSON_ENABLE)
list(APPEND SRC serializer.cpp serializer.hpp event_tracing.cpp event_tracing.hpp)
endif()
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "mlir_subgraph_extraction.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"
using namespace ngraph::descriptor;
using namespace ngraph::op;
using namespace ngraph::pass;
#define TI(x) std::type_index(typeid(x))
bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
{
// Create a CompiledKernel for all the ops in the function, except Parameters and Results.
NodeVector ck_ops;
for (auto op : func->get_ordered_ops())
{
if (TI(Parameter) != TI(*op) && TI(Result) != TI(*op))
{
ck_ops.push_back(op);
}
}
NodeVector ck_args;
for (auto& param : func->get_parameters())
{
ck_args.push_back(param);
}
NodeVector ck_outputs = std::move(get_subgraph_outputs(ck_ops, {} /*exclusions*/));
NGRAPH_ASSERT(ck_outputs.size() == 1) << "Unsupported subgraph with multiple outputs";
auto ck = std::make_shared<CompiledKernel>(ck_ops, ck_outputs, ck_args);
// Connect CompiledKernel to output nodes by replacing the output descriptors of the output
// nodes.
for (size_t i = 0, end = ck_outputs.size(); i < end; ++i)
{
auto& output_descs = ck_outputs[i]->get_outputs();
NGRAPH_ASSERT(output_descs.size() == 1) << "Unexpected multiple output descriptors";
auto& out_desc = output_descs[0];
// 'replace_output' invalidates the iterators of the original container. Use a copy instead.
std::set<Input*> input_descs{out_desc.get_inputs()};
for (Input* in_desc : input_descs)
{
in_desc->replace_output(ck, i);
}
}
return true;
}
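For context, the pass above is an ordinary FunctionPass, so it can also be exercised on its own through a pass manager; a minimal sketch, assuming the standard nGraph pass::Manager API (the CPU backend itself registers it via REGISTER_KNOBBED_PASS, as shown further down):

```cpp
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/mlir_subgraph_extraction.hpp"

// Wrap the body of 'func' (everything except Parameters and Results) in one CompiledKernel.
void extract_mlir_subgraph(std::shared_ptr<ngraph::Function> func)
{
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::MLIRSubgraphExtractionPass>();
    pass_manager.run_passes(func);
}
```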
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
/// This pass creates CompiledKernel ops enclosing sub-graphs that will be compiled and
/// executed by MLIR.
// TODO: WIP. Currently we only create a single CompiledKernel op for the whole function
// body.
class MLIRSubgraphExtractionPass : public ngraph::pass::FunctionPass
{
public:
MLIRSubgraphExtractionPass() {}
bool run_on_function(std::shared_ptr<Function> func) override;
};
}
}
......@@ -143,6 +143,13 @@ if (NGRAPH_HALIDE)
)
endif()
if (NGRAPH_MLIR_ENABLE)
set(SRC
${SRC}
builder/mlir_cpu_compiled_kernel.cpp
)
endif()
if (NGRAPH_CPU_ENABLE)
set(NGRAPH_CPU_DEBUGINFO_ENABLE 0 CACHE STRING "Enable debuginfo in the CPU backend")
......
//*****************************************************************************
// Copyright 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "contrib/mlir/compiler.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
using namespace ngraph;
using namespace ngraph::op;
using namespace ngraph::runtime::cpu;
using namespace ngraph::runtime::ngmlir;
#define TI(x) type_index(typeid(x))
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(CompiledKernel)
{
auto& functors = external_function->get_functors();
// Tensors haven't been allocated yet so we have to keep a pointer to the pointer
// that will hold the future memory address.
std::vector<void**> double_ptr_args;
for (const TensorViewWrapper& arg : args)
{
double_ptr_args.push_back(&external_function->get_tensor_data(arg.get_name()));
}
for (const TensorViewWrapper& result : out)
{
double_ptr_args.push_back(
&external_function->get_tensor_data(result.get_name()));
}
// Create functor that will be executed to compile and run this CompiledKernel.
// Note that 'double_ptr_args' must be captured by value since it's a local var.
auto functor = [node, double_ptr_args](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
// MLIR requires a list of type-erased pointers to arguments. Tensors must have
// been allocated at this point, so we can drop the extra level of indirection.
std::vector<void*> ptr_args;
for (auto& double_ptr : double_ptr_args)
{
ptr_args.push_back(*double_ptr);
}
// Compile nodes within the CompiledKernel op.
auto* compiled_kernel = static_cast<const CompiledKernel*>(node);
MLIRCompiler mlir_compiler(compiled_kernel, ptr_args);
// TODO: Decouple 'compile' and 'run' APIs. We want to be able to run the same
// jitted code on different arguments.
mlir_compiler.compile_and_run();
};
functors.emplace_back(functor);
}
}
}
}
#undef TI
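The builder above relies on a deferred-address trick: at functor-construction time the tensors are not allocated yet, so it captures void** slots and dereferences them only when the functor runs. A self-contained sketch of just that pattern (plain C++, no nGraph types):

```cpp
#include <cassert>
#include <vector>

int main()
{
    void* tensor_slot = nullptr;                  // will hold the tensor address once it is allocated
    std::vector<void**> double_ptr_args = {&tensor_slot};

    // Capture by value, as the builder does: the vector of slots is copied into the functor.
    auto functor = [double_ptr_args]() {
        std::vector<void*> ptr_args;
        for (void** slot : double_ptr_args)
        {
            ptr_args.push_back(*slot);            // resolve the real address only at run time
        }
        assert(ptr_args[0] != nullptr);           // by now the tensor has been allocated
    };

    int buffer[16] = {};                          // stand-in for the tensor memory allocated later
    tensor_slot = buffer;
    functor();
}
```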
......@@ -507,34 +507,3 @@ namespace ngraph
}
}
}
// TODO:
// Get rid of the #ifdefs by moving MLIR hooks to separate files in cpu backend
// we can then instead compile them conditionally based on NGRAPH_MLIR_ENABLE cmake flag
#ifdef NGRAPH_MLIR_ENABLE
using namespace ngraph::runtime::ngmlir;
using namespace ngraph::runtime::cpu;
CPUKernelFunctor Builder::build_mlir_single_output_binary_op(const ngraph::Node* node,
void*& arg0_tensor,
void*& arg1_tensor,
void*& out_tensor)
{
// TODO: Remove m_ip/op_list construction out of MLIRCompiler.
auto functor = [&, node](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
std::vector<const Node*> nodelist = {node};
// MLIR requires a list of type-erased pointer to arguments. Our arguments
// are already pointers, so we need to pass a double pointer.
std::vector<void*> ptr_args = {arg0_tensor, arg1_tensor, out_tensor};
MLIRCompiler mlirc(nodelist, ptr_args);
// TODO: Decouple 'compile' and 'run' APIs. We want to be able to run the
// same jitted code on different arguments.
mlirc.compile_and_run();
};
return functor;
}
#endif
\ No newline at end of file
......@@ -403,12 +403,6 @@ namespace ngraph
const std::vector<TensorViewWrapper>& out)
{
}
// TODO (dcab): Doc
static CPUKernelFunctor build_mlir_single_output_binary_op(const ngraph::Node* node,
void*& arg0_tensor,
void*& arg1_tensor,
void*& out_tensor);
};
}
}
......
......@@ -53,6 +53,7 @@
#include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
......
......@@ -141,6 +141,7 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/mlir_subgraph_extraction.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/propagate_cacheability.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
......@@ -1209,7 +1210,10 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
#if defined(NGRAPH_HALIDE)
REGISTER_KNOBBED_PASS(HalideSubgraphExtraction, true, ngraph::runtime::cpu::pass);
#endif
if (std::getenv("NGRAPH_MLIR") != nullptr)
{
REGISTER_KNOBBED_PASS(MLIRSubgraphExtractionPass, /*enable by default*/ true, ngraph::pass);
}
NodeVector nv_cwi; // We don't need CPUWorkspaceInsertion to return a list of indices
REGISTER_KNOBBED_PASS_WITH_ARGS(CPUWorkspaceInsertion, true, runtime::cpu::pass, nv_cwi, false);
REGISTER_KNOBBED_PASS_WITH_ARGS(CPUAssignment, true, runtime::cpu::pass, this);
......
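As a usage note, with this change the extraction pass only runs when the NGRAPH_MLIR environment variable is set, in addition to building with NGRAPH_MLIR_ENABLE. A hypothetical way to opt in from the host process, equivalent to exporting NGRAPH_MLIR=1 in the shell (assumes POSIX setenv):

```cpp
#include <cstdlib>

// Enable MLIR sub-graph extraction before the CPU backend compiles the function.
void enable_ngraph_mlir()
{
    setenv("NGRAPH_MLIR", "1", /*overwrite=*/1);
}
```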