Unverified Commit 925087ba authored by Pruthvi's avatar Pruthvi Committed by GitHub

[MLIR] MatMulBias Fused Op support in MlIR (#4104)

* - add fused_op.td to CmakeLists
- define pattern to fuse Wx + b and to replace with MatMulBias

* - remove table-gen LLVM_TARGET_DEFINATION for fused_ops_pattern.td,
fused_ops.td
- fix build issues

* - change pattern to to match MatMul instead of Dot
- support in CMake to register MatMulBias fused Op pattern

* - made changes to fusion pattern to match Add( Dot (op1, op2), bias) for
MatmulBias
- use applyPatternsGreedily instead of applyFullConversion in the graph
pass
- add unit test inter v/s CPU for MatMulBias

* - Affine lowering, verifier logic to NgMatMulBiasOp

* add missing header file

* - WIP, use NGGemm instead of NGMatMulBias

* -undo unintended changes

* Addressed PR comments

* - refactor the ctor of the NgDialectFusion pass
- register NgDialectFusion pass with the PassRegistration

* Address PR comments

* -add lit test for matmul+bias fusion

* -style fix lit test
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent fa0e7a4d
......@@ -34,6 +34,8 @@ set(SRC
core/pass/mlir_subgraph_extraction.hpp
core/pass/ng_dialect_builder.cpp
core/pass/ng_dialect_builder.hpp
core/pass/ng_dialect_fused_ops.cpp
core/pass/ng_dialect_fused_ops.hpp
runtime/cpu/memory_manager.cpp
runtime/cpu/cpu_runtime.cpp
runtime/cpu/cpu_callbacks.cpp
......@@ -125,7 +127,12 @@ ngraph_tablegen(ops_attributes.h.inc -gen-enum-decls)
ngraph_tablegen(ops_attributes.cpp.inc -gen-enum-defs)
add_public_tablegen_target(ngraph_ops_attributes_gen)
add_dependencies(mlir_backend ngraph_ops_gen ngraph_ops_interfaces_gen ngraph_ops_attributes_gen)
# tabel-gen ops fused_ops_pattern.td
set(LLVM_TARGET_DEFINITIONS core/pass/fused_ops_pattern.td)
ngraph_tablegen(fused_ops_pattern.h.inc -gen-rewriters)
add_public_tablegen_target(ngraph_ops_pattern_gen)
add_dependencies(mlir_backend ngraph_ops_gen ngraph_ops_interfaces_gen ngraph_ops_attributes_gen ngraph_ops_pattern_gen)
target_include_directories(mlir_backend PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
......
......@@ -23,6 +23,7 @@
#include "ngraph_dialect/ops.hpp"
#include "ngraph_dialect/type.hpp"
#include "pass/ng_dialect_builder.hpp"
#include "pass/ng_dialect_fused_ops.hpp"
#include "ngraph/check.hpp"
#include "ngraph/descriptor/tensor.hpp"
......@@ -117,4 +118,27 @@ void MLIRCompiler::buildNgDialectModule()
}
dumpMlirModule("nGraph Dialect Construction", m_module.get());
optimizeNgDialect();
}
void MLIRCompiler::optimizeNgDialect()
{
mlir::PassManager pm(&m_context);
pm.addPass(ngraph::pass::createNgDialectFusedOpsPass());
// Apply any generic pass manager command line options.
mlir::applyPassManagerCLOptions(pm);
if (failed(pm.run(m_module.get())))
{
NGRAPH_CHECK(false, "MLIR pass manager failed");
}
if (failed(m_module->verify()))
{
NGRAPH_CHECK(false, "Invalid module after NG dialect optimization");
}
dumpMlirModule("nGraph Dialect optimization", m_module.get());
}
......@@ -79,7 +79,7 @@ namespace ngraph
// Converts an nGraph sub-graph to MLIR nGraph dialect.
void buildNgDialectModule();
// Applies any nGraph dialect optimizations
void optimizeNgDialect() { /*TODO: Add Core NG dialect optimizations */}
void optimizeNgDialect();
private:
// Sub-graph to be compiled and executed with MLIR.
......
......@@ -21,7 +21,7 @@
#ifdef NG_FUSED_OPS
#else
#define NG_FUSED_OPS
// Squeeze Op
def NGSqueezeOp :
NG_OneResult_Op<"squeeze", [NoSideEffect, DeclareOpInterfaceMethods<FusedOp>]>,
......@@ -1009,6 +1009,4 @@ def NGConvBiasAddOp :
void setWithRelu(const Attribute& attr) {this->setAttr("withRelu", attr); }
}];
}
#endif //NG_FUSED_OPS
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
//===----------------------------------------------------------------------===//
//
// nGraph Dialect pattern match definitions for fused Op's using DRR
//
// This files declares nGraph fused operations that table-gen uses to create
// C++ code. For more information about tablegen,
// See https://llvm.org/docs/TableGen/index.html
//
// The output file fused_ops_pattern.h.inc is generated at build time
// Each def will corresponding to a C++ class
// NOTE: This file follows nGraph format style and MLIR naming convention since
// it does not expose public API to the rest of nGraph codebase and heavily
// depends on MLIR API.
//
//===----------------------------------------------------------------------===//
include "mlir/IR/OpBase.td"
include "core/ngraph_dialect/ops.td"
// Native code call to c++ helper funnction to create SGEMM Op
def createSgemmOp : NativeCodeCall<"createSgemmOp($_builder, $0.getDefiningOp(), $1, $2, $3)">;
// class for defining the pattern to fuse MatMul + Bias
def MatMulBiasPattern : Pat<(NGAddOp:$old_op ( NGDotOp $input1, $input2), $add_input),
(createSgemmOp $old_op, $input1, $input2, $add_input)>;
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// NOTE: This file follows nGraph format style.
// Follows nGraph naming convention for public APIs only, else MLIR naming convention.
#include "ng_dialect_fused_ops.hpp"
#include "contrib/mlir/core/ngraph_dialect/dialect.hpp"
#include "contrib/mlir/core/ngraph_dialect/ops.hpp"
#include "contrib/mlir/core/ngraph_dialect/type.hpp"
#include <llvm/IR/Module.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/IntegerSet.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>
#include <iostream>
using llvm::SmallVector;
using llvm::StringRef;
using llvm::ArrayRef;
using namespace ngraph;
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::op;
#define PASS_NAME "fuse-ngraph-dialect"
#define DEBUG_TYPE PASS_NAME
namespace mlir
{
static Value createSgemmOp(
PatternRewriter& rewriter, Operation* old_op, Value input1, Value input2, Value input3)
{
auto castedOp0 = dyn_cast_or_null<NGAddOp>(old_op);
SmallVector<Value, 4> values{input1, input2, input3};
SmallVector<NamedAttribute, 4> attrs;
attrs.emplace_back(
rewriter.getIdentifier("alpha"),
rewriter.getFloatAttr(mlir::Builder(rewriter.getContext()).getF32Type(), 1.0));
attrs.emplace_back(
rewriter.getIdentifier("beta"),
rewriter.getFloatAttr(mlir::Builder(rewriter.getContext()).getF32Type(), 1.0));
attrs.emplace_back(rewriter.getIdentifier("transA"), rewriter.getBoolAttr(false));
attrs.emplace_back(rewriter.getIdentifier("transB"), rewriter.getBoolAttr(false));
SmallVector<Type, 4> types;
for (auto v : castedOp0.getODSResults(0))
{
types.push_back(v.getType());
}
return rewriter.create<NGGemmOp>(castedOp0.getLoc(), types, values, attrs);
}
#include "fused_ops_pattern.h.inc"
}
namespace
{
class NgDialectFusedOpsPass : public mlir::ModulePass<NgDialectFusedOpsPass>
{
public:
NgDialectFusedOpsPass() {}
private:
void runOnModule() override;
};
}
void NgDialectFusedOpsPass::runOnModule()
{
OwningRewritePatternList patterns;
mlir::populateWithGenerated(&getContext(), &patterns);
// Gather functions to be processed. Note that new functions will be added to module as part
// of the function signature conversion so we have to collect the original ones before hand.
SmallVector<FuncOp, 2> origFuncOps(getModule().getOps<FuncOp>());
for (auto origFunc : origFuncOps)
{
applyPatternsGreedily(origFunc, patterns);
}
}
std::unique_ptr<Pass> ngraph::pass::createNgDialectFusedOpsPass()
{
return std::make_unique<NgDialectFusedOpsPass>();
}
static PassRegistration<NgDialectFusedOpsPass>
pass(PASS_NAME, "Fuse ngraph dialct based on the pattern match");
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// NOTE: This file follows nGraph format style.
// Follows nGraph naming convention for public APIs only, else MLIR naming convention.
#pragma once
#include <mlir/Pass/Pass.h>
namespace ngraph
{
namespace pass
{
std::unique_ptr<mlir::Pass> createNgDialectFusedOpsPass();
}
}
......@@ -4025,4 +4025,39 @@ TEST(cpu_fusion, validate_fuse_gru_inputs)
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, mlir_matmul_bias)
{
Shape shape{};
Shape shape_w{2, 4};
Shape shape_x{4, 1};
Shape shape_b{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto int_func = make_shared<Function>(NodeVector{add}, ParameterVector{A, B, C});
auto cpu_func = make_shared<Function>(NodeVector{add}, ParameterVector{A, B, C});
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_func->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(cpu_func, args, "INTERPRETER");
auto cpu_results = execute(cpu_func, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
#endif
// RUN: ngraph-opt %s -fuse-ngraph-dialect -split-input-file | FileCheck %s
// Verify that operations fused using pattern matcher are properly replaced with correct Fused Op.
// -----
// matmul+bias
// CHECK-LABEL: func @matmul_bias_fusion(%arg0: !ng.tensor<2x4xf32>, %arg1: !ng.tensor<4x1xf32>, %arg2: !ng.tensor<2x1xf32>) -> !ng.tensor<2x1xf32> {
// CHECK: %0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = {{.*}}: f32, beta = {{.*}} : f32, transA = {{.*}}, transB = {{.*}}} : (!ng.tensor<2x4xf32>, !ng.tensor<4x1xf32>, !ng.tensor<2x1xf32>) -> !ng.tensor<2x1xf32>
// CHECK: "ng.return"(%0) : (!ng.tensor<2x1xf32>) -> ()
func @matmul_bias_fusion(%arg0: !ng.tensor<2x4xf32>, %arg1: !ng.tensor<4x1xf32>, %arg2: !ng.tensor<2x1xf32>) -> !ng.tensor<2x1xf32> {
%0 = "ng.dot"(%arg0, %arg1) : (!ng.tensor<2x4xf32>, !ng.tensor<4x1xf32>) -> !ng.tensor<2x1xf32>
%1 = "ng.add"(%0, %arg2) : (!ng.tensor<2x1xf32>, !ng.tensor<2x1xf32>) -> !ng.tensor<2x1xf32>
"ng.return"(%1) : (!ng.tensor<2x1xf32>) -> ()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment