Unverified Commit fc9a7dea authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #3298 from NervanaSystems/nmostafa/recompile

[MLIR] Re-compile sub-graph once on first invocation
parents 5d3456e4 956e8b3a
......@@ -68,22 +68,11 @@ using llvm::SmallVector;
using llvm::StringRef;
using llvm::make_unique;
using llvm::ArrayRef;
using namespace ngraph::runtime::ngmlir;
#define COMPILE_OP_DECL(op_name) \
create_op<op_name>(MLIRCompiler & compiler, const ngraph::Node* ng_node)
MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
const std::vector<void*>& external_tensors)
: m_compiled_kernel(compiled_kernel)
, m_external_tensors(external_tensors)
{
NGRAPH_CHECK((m_compiled_kernel->get_arguments().size() +
m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size(),
"Number of arguments and outputs doesn't match number of tensors");
}
void MLIRCompiler::init_mlir()
{
// Mutex to safely initialize MLIR.
......@@ -101,12 +90,15 @@ void MLIRCompiler::init_mlir()
}
}
void MLIRCompiler::compile_and_run()
void MLIRCompiler::compile()
{
build_ng_dialect_module();
lower_ng_dialect();
optimize();
bind_arguments();
}
void MLIRCompiler::run(std::vector<void*>& external_tensors)
{
bind_arguments(external_tensors);
execute();
cleanup();
}
......@@ -241,9 +233,10 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens
return it->second;
}
// Lowers nGraph dialect to affine dialect.
// Lowers nGraph dialect all the way to LLVM module.
void MLIRCompiler::lower_ng_dialect()
{
// Lower NG dialect to Affine
mlir::PassManager pm;
pm.addPass(mlir::createDialectLoweringPass(this));
pm.addPass(mlir::createCanonicalizerPass());
......@@ -256,13 +249,48 @@ void MLIRCompiler::lower_ng_dialect()
}
dump_mlir_module("Affine Dialect Dump:");
optimize();
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
// Lower Standard dialect to LLVM dialect.
// TODO: Do this via PassManager
mlir::LLVMTypeConverter llvm_converter(&m_context);
OwningRewritePatternList patterns;
mlir::populateStdToLLVMConversionPatterns(llvm_converter, patterns);
mlir::ConversionTarget target(m_context);
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
auto result = applyConversionPatterns(*m_module, target, llvm_converter, std::move(patterns));
NGRAPH_CHECK(succeeded(result), "Standard to LLVM dialect conversion failed");
dump_mlir_module("LLVM-IR Dialect Dump:");
// Lower to LLVM BC and optimize
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
unsigned opt_level = 3;
if (char* opt_level_str = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
{
opt_level = std::stoi(opt_level_str);
NGRAPH_CHECK(opt_level >= 0 && opt_level <= 3, "Invalid optimization level");
}
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer to run
// LLVM optimizations at level 3.
auto llvm_transformer =
mlir::makeOptimizingTransformer(opt_level /*optLevel*/, 0 /*sizeLevel*/);
auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
m_engine = std::move(maybeEngine.get());
}
// Receives affine dialect as input and applies affine and standard dialect based optimizations.
// Lowering from affine dialect to standard dialect happens along the way. Output consists of
// standard dialect only ops.
void MLIRCompiler::optimize()
{
// Lower Affine to Std Dialect
mlir::PassManager pm;
// Lower affine ops
pm.addPass(mlir::createLowerAffinePass());
......@@ -458,33 +486,39 @@ mlir::Operation* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_nod
op->setAttr("axes", red_axes_attr);
return op;
}
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function.
void MLIRCompiler::bind_arguments()
void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors)
{
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
mlir::Function* func = m_module->getNamedFunction("main");
NGRAPH_CHECK(func && !func->getBlocks().empty(), "Function not found");
// Set external arguments
NGRAPH_CHECK(m_compiled_kernel, "No compiled kernel set for compiler");
NGRAPH_CHECK((m_compiled_kernel->get_arguments().size() +
m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size(),
"Number of arguments and outputs doesn't match number of tensors");
m_external_tensors = &external_tensors;
// Create list with a type-erased double pointer for each invocation arguments.
// We currently use 'allocateMemRefArguments', which creates a
// SmallVector<StaticFloatMemref*>. StaticFloatMemref is just a struct with the
// actual pointer to the data.
// create MemRef args
auto expected_arguments = allocate_memref_args(func);
auto expected_arguments = allocate_memref_args();
NGRAPH_CHECK(expected_arguments.size(), "Arguments can't be created");
m_invoke_args = std::move(expected_arguments);
NGRAPH_CHECK(m_invoke_args.size() == m_external_tensors.size(),
NGRAPH_CHECK(m_invoke_args.size() == m_external_tensors->size(),
"Number of external tensors doesn't match number of function arguments");
// Assign external tensor pointers to invocation arguments.
for (size_t i = 0, num_args = m_invoke_args.size(); i < num_args; ++i)
{
((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)m_external_tensors[i];
((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)(*m_external_tensors)[i];
}
// Add pointer to memory manager
......@@ -501,39 +535,6 @@ void MLIRCompiler::bind_arguments()
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
void MLIRCompiler::execute()
{
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
// Lower Standard dialect to LLVM dialect.
mlir::LLVMTypeConverter llvm_converter(&m_context);
OwningRewritePatternList patterns;
mlir::populateStdToLLVMConversionPatterns(llvm_converter, patterns);
mlir::ConversionTarget target(m_context);
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
auto result = applyConversionPatterns(*m_module, target, llvm_converter, std::move(patterns));
NGRAPH_CHECK(succeeded(result), "Standard to LLVM dialect conversion failed");
dump_mlir_module("LLVM-IR Dialect Dump:");
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
unsigned opt_level = 3;
if (char* opt_level_str = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
{
opt_level = std::stoi(opt_level_str);
NGRAPH_CHECK(opt_level >= 0 && opt_level <= 3, "Invalid optimization level");
}
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer to run
// LLVM optimizations at level 3.
auto llvm_transformer =
mlir::makeOptimizingTransformer(opt_level /*optLevel*/, 0 /*sizeLevel*/);
auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
m_engine = std::move(maybeEngine.get());
// Invoke the JIT-compiled function with the arguments. Note that, for API
// uniformity reasons, it takes a list of type-erased pointers to arguments.
// Please, note that 'invoke' method is overloaded with a parameter pack version.
......@@ -560,32 +561,19 @@ void MLIRCompiler::cleanup()
m_mem_mgr.freeAll();
}
SmallVector<void*, 8> MLIRCompiler::allocate_memref_args(mlir::Function* func)
SmallVector<void*, 8> MLIRCompiler::allocate_memref_args()
{
SmallVector<void*, 8> args;
args.reserve(func->getNumArguments());
for (const auto& arg : func->getArguments())
for (auto i = 0; i < m_external_tensors->size(); i++)
{
auto descriptor = allocate_memref_descriptor(arg->getType());
if (!descriptor)
{
continue;
}
auto descriptor = allocate_memref_descriptor();
args.push_back(descriptor);
}
return args;
}
mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type type)
mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor()
{
auto memRefType = type.dyn_cast<mlir::MemRefType>();
if (!memRefType)
{
return nullptr;
}
NGRAPH_CHECK(memRefType.getNumDynamicDims() == 0, "No support for dynamic shapes");
// We only use StaticFloatMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs
auto* descriptor =
......
......@@ -63,11 +63,16 @@ namespace ngraph
using TensorList = std::vector<descriptor::Tensor*>;
using TypeList = llvm::SmallVector<mlir::Type, 4>;
MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
const std::vector<void*>& external_tensors);
MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel)
: m_compiled_kernel(compiled_kernel)
{
}
/// Compiles a subgraph with MLIR
void compile();
/// Compiles and runs a subgraph in MLIR.
void compile_and_run();
/// Executes a pre-compiled subgraph
void run(std::vector<void*>& external_tensors);
/// Returns the memory manager used by this sub-graph compiler.
MLIRMemMgr& get_mem_mgr() { return m_mem_mgr; }
......@@ -88,7 +93,7 @@ namespace ngraph
void build_ng_dialect_module();
void lower_ng_dialect();
void optimize();
void bind_arguments();
void bind_arguments(std::vector<void*>& external_tensors);
void execute();
void cleanup();
......@@ -120,10 +125,10 @@ namespace ngraph
void create_return();
/// Helper to create memref arguments for MLIR function signature
llvm::SmallVector<void*, 8> allocate_memref_args(mlir::Function* func);
llvm::SmallVector<void*, 8> allocate_memref_args();
/// Helper to allocate a mem ref object. Handles static shapes only for now.
mlir::StaticFloatMemRef* allocate_memref_descriptor(mlir::Type type);
mlir::StaticFloatMemRef* allocate_memref_descriptor();
/// Helper to dump MLIR module into llvm::dbgs prepended by the message \p msg.
void dump_mlir_module(const std::string msg);
......@@ -133,7 +138,7 @@ namespace ngraph
const ngraph::op::CompiledKernel* m_compiled_kernel;
// Pointers to externally allocated memory for sub-graph's input and output tensors.
const std::vector<void*>& m_external_tensors;
std::vector<void*>* m_external_tensors;
// Arguments for the MLIR function generated for the nGraph sub-graph.
llvm::SmallVector<void*, 8> m_invoke_args;
......
......@@ -65,14 +65,25 @@ namespace ngraph
{
ptr_args.push_back(ctx->buffer_data[buffer_index]);
}
// Compile nodes within the CompiledKernel op.
auto* compiled_kernel = static_cast<const CompiledKernel*>(node);
CompiledKernel* compiled_kernel =
static_cast<CompiledKernel*>(const_cast<Node*>(node));
bool is_module_ready = true;
auto it = ctx->mlir_compilers.find(compiled_kernel);
if (it == ctx->mlir_compilers.end())
{
// create a new compiler for the CK
ctx->mlir_compilers.emplace(compiled_kernel, compiled_kernel);
is_module_ready = false;
}
MLIRCompiler mlir_compiler(compiled_kernel, ptr_args);
// TODO: Decouple 'compile' and 'run' APIs. We want to be able to run the same
// jitted code on different arguments.
mlir_compiler.compile_and_run();
MLIRCompiler& mlir_compiler = ctx->mlir_compilers.find(compiled_kernel)->second;
if (!is_module_ready)
{
mlir_compiler.compile();
}
mlir_compiler.run(ptr_args);
};
functors.emplace_back(functor);
......
......@@ -25,6 +25,11 @@
#include <tbb/flow_graph.h>
#include <tbb/global_control.h>
#include <tbb/task_scheduler_init.h>
#include "ngraph/op/experimental/compiled_kernel.hpp"
#ifdef NGRAPH_MLIR_ENABLE
#include "contrib/mlir/compiler.hpp"
#endif
namespace mkldnn
{
......@@ -66,6 +71,14 @@ namespace ngraph
State* const* states;
std::set<size_t> breakpoints;
size_t pc;
#ifdef NGRAPH_MLIR_ENABLE
/// Maps CompiledKernel nodes to their MLIR compiler
/// The MLIR compiler caches the compiled code on the first invocation,
/// and may in the future support re-compilation
std::unordered_map<ngraph::op::CompiledKernel*,
ngraph::runtime::ngmlir::MLIRCompiler>
mlir_compilers;
#endif
};
}
......
......@@ -248,3 +248,36 @@ NGRAPH_TEST(${BACKEND_NAME}, mlir_subgraphs_cycle)
EXPECT_TRUE(
test::all_close_f(read_vector<float>(result), vector<float>{70, 80, 90, 136, 164, 192}));
}
NGRAPH_TEST(${BACKEND_NAME}, mlir_multi_call)
{
Shape shape_in1{2, 3};
Shape shape_in2{3, 3};
Shape shape_out{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto dot = make_shared<op::Dot>(A, B);
auto C = make_shared<op::Parameter>(element::f32, shape_in1);
auto add = make_shared<op::Add>(dot, C);
auto f = make_shared<Function>(add, ParameterVector{A, B, C});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape_in1);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape_in2);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shape_in1);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape_out);
copy_data(a, vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
copy_data(b, vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
copy_data(c, vector<float>{5.f, 4.f, 3.f, 2.f, 1.f, 0.f});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b, c});
handle->call_with_validate({result}, {a, b, c});
handle->call_with_validate({result}, {a, b, c});
handle->call_with_validate({result}, {a, b, c});
EXPECT_TRUE(test::all_close_f(read_vector<float>(result),
vector<float>{35.f, 40.f, 45.f, 68.f, 82.f, 96.f}));
}
\ No newline at end of file
......@@ -30,7 +30,6 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
......@@ -1544,241 +1543,6 @@ TEST(cpu_fusion, backwards_maxpool_with_indices_n4_c1_hw4_2x2_max)
EXPECT_TRUE(test::all_close_f(read_vector<float>(output), expected, MIN_FLOAT_TOLERANCE_BITS));
}
#if defined(NGRAPH_HALIDE)
TEST(cpu_fusion, compiled_kernel_one_input_one_output_halide)
{
Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto relu_a = make_shared<op::Relu>(A);
auto relu_relu_a = make_shared<op::Relu>(relu_a);
auto ck = make_shared<op::CompiledKernel>(
NodeVector{relu_a, relu_relu_a}, NodeVector{relu_relu_a}, NodeVector{A});
auto f = make_shared<Function>(NodeVector{ck}, ParameterVector{A});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shapeA);
vector<float> dataA{-1, 4, -1, 4};
copy_data(a, dataA);
vector<float> expected{0, 4, 0, 4};
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(read_vector<float>(result), expected));
}
TEST(cpu_fusion, compiled_kernel_two_input_two_output_halide)
{
Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeA);
auto relu_a = make_shared<op::Relu>(A);
auto add_ab = make_shared<op::Add>(relu_a, B);
auto ck = make_shared<op::CompiledKernel>(
NodeVector{relu_a, add_ab}, NodeVector{relu_a, add_ab}, NodeVector{A, B});
auto goe1 = make_shared<op::GetOutputElement>(ck, 0);
auto goe2 = make_shared<op::GetOutputElement>(ck, 1);
auto f = make_shared<Function>(NodeVector{goe1, goe2}, ParameterVector{A, B});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> result_relu = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> result_add = backend->create_tensor(element::f32, shapeA);
vector<float> dataA{-1, 4, -1, 4};
vector<float> dataB{0, 4, 0, 4};
copy_data(a, dataA);
copy_data(b, dataB);
vector<float> expected_relu{0, 4, 0, 4};
vector<float> expected_add{4, 4, 4, 4};
auto handle = backend->compile(f);
handle->call_with_validate({result_relu, result_add}, {a, b});
EXPECT_TRUE(test::all_close(read_vector<float>(result_relu), expected_relu));
}
TEST(cpu_fusion, compiled_kernel_embedded_graph_halide)
{
Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeA);
auto neg_a = make_shared<op::Negative>(A);
auto neg_b = make_shared<op::Negative>(B);
auto add = neg_a + neg_b;
auto ck =
make_shared<op::CompiledKernel>(NodeVector{add}, NodeVector{add}, NodeVector{neg_a, neg_b});
auto f = make_shared<Function>(NodeVector{ck}, ParameterVector{A, B});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shapeA);
vector<float> dataA{1, 4, 1, 4};
copy_data(a, dataA);
vector<float> dataB{1, 2, 3, 4};
copy_data(b, dataB);
vector<float> expected{-2, -6, -4, -8};
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f(read_vector<float>(result), expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(cpu_fusion, compiled_kernel_two_inputs_one_output_halide)
{
Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeA);
auto add = A + B;
auto ck = make_shared<op::CompiledKernel>(NodeVector{add}, NodeVector{add}, NodeVector{A, B});
auto f = make_shared<Function>(NodeVector{ck}, ParameterVector{A, B});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shapeA);
vector<float> dataA{1, 4, 1, 4};
copy_data(a, dataA);
vector<float> dataB{1, 2, 3, 4};
copy_data(b, dataB);
vector<float> expected{2, 6, 4, 8};
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f(read_vector<float>(result), expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(cpu_fusion, compiled_kernel_multiple_outputs_halide)
{
Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeA);
auto C = make_shared<op::Parameter>(element::f32, shapeA);
auto D = make_shared<op::Parameter>(element::f32, shapeA);
auto neg_a = make_shared<op::Negative>(A);
auto neg_b = make_shared<op::Negative>(B);
auto add_ab = neg_a + neg_b;
auto add_cd = C + B;
auto add_cd_abs = make_shared<op::Abs>(add_cd);
auto add_ab_abs = make_shared<op::Abs>(add_ab);
auto add_aab = add_ab_abs + A;
auto add_cdd = add_cd_abs + D;
auto ck = make_shared<op::CompiledKernel>(
NodeVector{neg_a, neg_b, add_ab, add_cd, add_cd_abs, add_ab_abs, add_aab, add_cdd},
NodeVector{add_aab, add_cdd, neg_b},
NodeVector{A, B, C, D});
auto add_aab_goe = std::make_shared<op::GetOutputElement>(ck, 0);
auto add_cdd_goe = std::make_shared<op::GetOutputElement>(ck, 1);
auto neg_b_goe = std::make_shared<op::GetOutputElement>(ck, 2);
auto f = make_shared<Function>(NodeVector{add_aab_goe, add_cdd_goe, neg_b_goe},
ParameterVector{A, B, C, D});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> d = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> r1 = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> r2 = backend->create_tensor(element::f32, shapeA);
shared_ptr<runtime::Tensor> r3 = backend->create_tensor(element::f32, shapeA);
vector<float> dataA{1, 4, 1, 4};
vector<float> dataB{3, 3, 3, 9};
vector<float> dataC{1, 2, 3, 4};
vector<float> dataD{-2, 2, -1, 1};
copy_data(a, dataA);
copy_data(b, dataB);
copy_data(c, dataC);
copy_data(d, dataD);
auto handle = backend->compile(f);
handle->call_with_validate({r1, r2, r3}, {a, b, c, d});
vector<float> expected1{5, 11, 5, 17};
vector<float> expected2{2, 7, 5, 14};
vector<float> expected3{-3, -3, -3, -9};
EXPECT_TRUE(test::all_close_f(read_vector<float>(r1), expected1, MIN_FLOAT_TOLERANCE_BITS));
EXPECT_TRUE(test::all_close_f(read_vector<float>(r2), expected2, MIN_FLOAT_TOLERANCE_BITS));
EXPECT_TRUE(test::all_close_f(read_vector<float>(r3), expected3, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(cpu_fusion, compiled_kernel_copy_with_new_args)
{
Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::i32, shapeA);
auto B = make_shared<op::Parameter>(element::i32, shapeA);
auto C = make_shared<op::Parameter>(element::i32, shapeA);
auto D = make_shared<op::Parameter>(element::i32, shapeA);
auto neg_a = make_shared<op::Negative>(A);
auto neg_b = make_shared<op::Negative>(B);
auto add_ab = neg_a + neg_b;
auto add_cd = C + B;
auto add_cd_abs = make_shared<op::Abs>(add_cd);
auto add_ab_abs = make_shared<op::Abs>(add_ab);
auto add_aab = add_ab_abs + A;
auto add_cdd = add_cd_abs + D;
auto ck = make_shared<op::CompiledKernel>(
NodeVector{neg_a, neg_b, add_ab, add_cd, add_cd_abs, add_ab_abs, add_aab, add_cdd},
NodeVector{add_aab, add_cdd, neg_b},
NodeVector{A, B, C, D});
auto add_aab_goe = std::make_shared<op::GetOutputElement>(ck, 0);
auto add_cdd_goe = std::make_shared<op::GetOutputElement>(ck, 1);
auto neg_b_goe = std::make_shared<op::GetOutputElement>(ck, 2);
auto f = make_shared<Function>(NodeVector{add_aab_goe, add_cdd_goe, neg_b_goe},
ParameterVector{A, B, C, D});
auto copy_f = clone_function(*f);
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::i32, shapeA);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::i32, shapeA);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::i32, shapeA);
shared_ptr<runtime::Tensor> d = backend->create_tensor(element::i32, shapeA);
shared_ptr<runtime::Tensor> r1 = backend->create_tensor(element::i32, shapeA);
shared_ptr<runtime::Tensor> r2 = backend->create_tensor(element::i32, shapeA);
shared_ptr<runtime::Tensor> r3 = backend->create_tensor(element::i32, shapeA);
shared_ptr<runtime::Tensor> copy_r1 = backend->create_tensor(element::i32, shapeA);
shared_ptr<runtime::Tensor> copy_r2 = backend->create_tensor(element::i32, shapeA);
shared_ptr<runtime::Tensor> copy_r3 = backend->create_tensor(element::i32, shapeA);
vector<int> dataA{1, 4, 1, 4};
vector<int> dataB{3, 3, 3, 9};
vector<int> dataC{1, 2, 3, 4};
vector<int> dataD{-2, 2, -1, 1};
copy_data(a, dataA);
copy_data(b, dataB);
copy_data(c, dataC);
copy_data(d, dataD);
auto handle = backend->compile(f);
handle->call_with_validate({r1, r2, r3}, {a, b, c, d});
auto h1 = backend->compile(copy_f);
h1->call_with_validate({copy_r1, copy_r2, copy_r3}, {a, b, c, d});
EXPECT_EQ(read_vector<int>(r1), read_vector<int>(copy_r1));
EXPECT_EQ(read_vector<int>(r2), read_vector<int>(copy_r2));
EXPECT_EQ(read_vector<int>(r3), read_vector<int>(copy_r3));
}
#endif
static std::shared_ptr<ngraph::Function> make_forward_function()
{
Shape shape_a{10, 3, 28, 28};
......@@ -2298,202 +2062,6 @@ TEST(cpu_fusion, rnn_fprop_1_lstm_cell)
EXPECT_TRUE(test::all_close(expected_ct, read_vector<float>(result_ct)));
}
#if 0
TEST(cpu_fusion, compiled_kernel_fusion_multiple_groups_pruned)
{
auto make_function = []() -> std::shared_ptr<Function> {
Shape shape{};
auto a = make_shared<op::Parameter>(element::f32, shape);
auto b = make_shared<op::Parameter>(element::f32, shape);
auto c = make_shared<op::Parameter>(element::f32, shape);
auto add_ab = a + b;
auto add_abs = std::make_shared<op::Abs>(add_ab);
auto abs_neg = std::make_shared<op::Negative>(add_abs);
auto sub_c_neg = c - abs_neg;
auto d = make_shared<op::Parameter>(element::f32, shape);
auto d_abs = std::make_shared<op::Abs>(d);
auto add_d = d_abs + add_ab;
auto neg_d = std::make_shared<op::Negative>(add_d);
auto mul_cd = neg_d * sub_c_neg;
auto f =
std::make_shared<Function>(ngraph::NodeVector{mul_cd}, ParameterVector{a, b, c, d});
return f;
};
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUCompiledKernelFusion>(3);
auto cpu_f = make_function();
auto int_f = make_function();
pass_manager.run_passes(cpu_f);
test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args;
size_t ckn = count_ops_of_type<op::CompiledKernel>(cpu_f);
ASSERT_GT(ckn, 0);
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, compiled_kernel_fusion_bounded_relu)
{
auto make_function = []() -> std::shared_ptr<Function> {
Shape shape{};
auto a = make_shared<op::Parameter>(element::f32, shape);
auto relu = make_shared<op::Relu>(a);
auto upper_bound =
op::Constant::create<float>(element::f32, shape, std::vector<float>{6.0f});
auto minn = make_shared<op::Minimum>(relu, upper_bound);
auto absn = make_shared<op::Abs>(minn);
auto negn = std::make_shared<op::Negative>(absn);
auto f = std::make_shared<Function>(ngraph::NodeVector{negn}, ParameterVector{a});
return f;
};
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before_relu_fusion.png");
pass_manager.register_pass<runtime::cpu::pass::CPUCompiledKernelFusion>(3);
pass_manager.register_pass<pass::VisualizeTree>("after_relu_fusion.png");
auto cpu_f = make_function();
auto int_f = make_function();
pass_manager.run_passes(cpu_f);
test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args;
size_t ckn = count_ops_of_type<op::CompiledKernel>(cpu_f);
ASSERT_GT(ckn, 0);
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, compiled_kernel_fusion_multiple_groups)
{
auto make_function = []() -> std::shared_ptr<Function> {
Shape shape{};
auto a = make_shared<op::Parameter>(element::f32, shape);
auto b = make_shared<op::Parameter>(element::f32, shape);
auto c = make_shared<op::Parameter>(element::f32, shape);
auto add_ab = a + b;
auto add_abs = std::make_shared<op::Abs>(add_ab);
auto abs_neg = std::make_shared<op::Negative>(add_abs);
auto sub_c_neg = c - abs_neg;
auto d = make_shared<op::Parameter>(element::f32, shape);
auto d_abs = std::make_shared<op::Abs>(d);
auto add_d = d_abs + add_ab;
auto neg_d = std::make_shared<op::Negative>(add_d);
auto mul_cd = neg_d * sub_c_neg;
auto f =
std::make_shared<Function>(ngraph::NodeVector{mul_cd}, ParameterVector{a, b, c, d});
return f;
};
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUCompiledKernelFusion>(2);
auto cpu_f = make_function();
auto int_f = make_function();
pass_manager.run_passes(cpu_f);
test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args;
size_t ckn = count_ops_of_type<op::CompiledKernel>(cpu_f);
ASSERT_GT(ckn, 0);
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, compiled_kernel_fusion_one_group)
{
auto make_function = []() -> std::shared_ptr<Function> {
Shape shape{};
auto a = make_shared<op::Parameter>(element::f32, shape);
auto b = make_shared<op::Parameter>(element::f32, shape);
auto c = make_shared<op::Parameter>(element::f32, shape);
auto add_ab = a + b;
auto add_abs = std::make_shared<op::Abs>(add_ab);
auto abs_neg = std::make_shared<op::Negative>(add_abs);
auto sub_c_neg = c - abs_neg;
auto d = make_shared<op::Parameter>(element::f32, shape);
auto add_d = sub_c_neg + d;
auto abs_add_d = std::make_shared<op::Abs>(add_d);
auto e = make_shared<op::Parameter>(element::f32, shape);
auto add_e = e + abs_add_d;
auto neg_e = std::make_shared<op::Negative>(add_e);
auto f = std::make_shared<Function>(ngraph::NodeVector{neg_e},
ParameterVector{a, b, c, d, e});
return f;
};
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUCompiledKernelFusion>(2);
auto cpu_f = make_function();
auto int_f = make_function();
pass_manager.run_passes(cpu_f);
test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args;
size_t ckn = count_ops_of_type<op::CompiledKernel>(cpu_f);
ASSERT_GT(ckn, 0);
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
#endif
void sigmoid_multiply_fusion_forward_compute(runtime::Backend* backend,
const ParameterVector& input_params,
const vector<vector<float>>& input_data,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment