Commit c737a573 authored by Amy Zhuang, committed by Scott Cyphers

[MLIR] Use call back for MatMul. (#3838)

* [MLIR] Use call back for MatMul.

* Use callback for Gemm.

* Use mkldnn callback for Softmax.

* Address PR feedback.

* Fix merge errors.

* Change to tail allocation struct.

* Use mkldnn callback for AvgPool.

* Add callbacks for AvgPoolBackprop, MaxPool, and MaxPoolBackprop.

* Fix merge errors.

* Use UnrankedMemRefType for callbacks.

* Address PR feedback.

* Cleanup.

* Address PR feedback.

* Fix a bug.

* Use global variable to hold attributes.

* Convert layout if needed for pooling.

* Address PR feedback.

* Add header.

* Address PR feedback.

* Update Copyright to 2017-2020.

* Address PR feedback.
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent e8c0282c
@@ -36,6 +36,7 @@ set(SRC
core/pass/ng_dialect_builder.hpp
runtime/cpu/memory_manager.cpp
runtime/cpu/cpu_runtime.cpp
+runtime/cpu/cpu_callbacks.cpp
utils.cpp
)
@@ -90,7 +91,8 @@ target_link_libraries(
)
# Link ngraph
-target_link_libraries(mlir_backend PUBLIC ngraph)
+target_link_libraries(mlir_backend PUBLIC ngraph libmkl libmkldnn)
+target_include_directories(mlir_backend SYSTEM PUBLIC libmkldnn)
# table-gen dialect ops
# include table-gen helpers
......
@@ -33,6 +33,7 @@
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
+#include <mlir/IR/StandardTypes.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Target/LLVMIR.h>
#include <mlir/Transforms/DialectConversion.h>
......
@@ -27,22 +27,29 @@
MLIR_OP(NGAddOp , true )
MLIR_OP(NGArgMaxRedOp , false )
MLIR_OP(NGArgMinRedOp , false )
+MLIR_OP(NGAvgPoolOp , false )
+MLIR_OP(NGAvgPoolBackpropOp , false )
MLIR_OP(NGConcatOp , true )
MLIR_OP(NGConvolutionOp , false )
MLIR_OP(NGDivOp , true )
MLIR_OP(NGDotOp , false )
MLIR_OP(NGGatherOp , false )
+MLIR_OP(NGGemmOp , false )
MLIR_OP(NGGreaterOp , true )
MLIR_OP(NGLessOp , true )
MLIR_OP(NGGreaterEqOp , true )
MLIR_OP(NGLessEqOp , true )
MLIR_OP(NGEqOp , true )
MLIR_OP(NGNotEqOp , true )
+MLIR_OP(NGMatMulOp , false )
MLIR_OP(NGMulOp , true )
MLIR_OP(NGMaxOp , true )
+MLIR_OP(NGMaxPoolOp , false )
+MLIR_OP(NGMaxPoolBackpropOp , false )
MLIR_OP(NGMinOp , true )
MLIR_OP(NGNegOp , true )
MLIR_OP(NGReluOp , true )
+MLIR_OP(NGSoftMaxOp , false )
MLIR_OP(NGSubOp , true )
MLIR_LAST_OP(NGReturnOp , false )
......
@@ -28,24 +28,7 @@
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
-#include "ngraph/op/add.hpp"
-#include "ngraph/op/argmax.hpp"
-#include "ngraph/op/argmin.hpp"
-#include "ngraph/op/concat.hpp"
-#include "ngraph/op/convolution.hpp"
-#include "ngraph/op/divide.hpp"
-#include "ngraph/op/dot.hpp"
-#include "ngraph/op/experimental/compiled_kernel.hpp"
-#include "ngraph/op/gather.hpp"
-#include "ngraph/op/greater.hpp"
-#include "ngraph/op/less.hpp"
-#include "ngraph/op/maximum.hpp"
-#include "ngraph/op/minimum.hpp"
-#include "ngraph/op/multiply.hpp"
-#include "ngraph/op/negative.hpp"
-#include "ngraph/op/relu.hpp"
-#include "ngraph/op/subtract.hpp"
-#include "ngraph/op/util/index_reduction.hpp"
+#include "ngraph/ops.hpp"
#include "ngraph/type/element_type.hpp"
#include "contrib/mlir/utils.hpp"
......
@@ -282,7 +282,7 @@ def NGMVN :
}
// MatMul Op
-def NGMatMul :
+def NGMatMulOp :
    NG_OneResult_Op<"matmul", [NoSideEffect, DeclareOpInterfaceMethods<FusedOp>]>,
    Arguments<(ins NG_TensorType:$A, NG_TensorType:$B,
               DefaultValuedAttr<BoolAttr, "false">:$transposeA,
......
@@ -309,6 +309,55 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op)
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGMatMulOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGGemmOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGSoftMaxOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGAvgPoolOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGAvgPoolBackpropOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGMaxPoolOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGMaxPoolBackpropOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
namespace mlir
{
#include "ops_interfaces.cpp.inc"
@@ -401,7 +450,7 @@ void mlir::NGLSTMCellOp::decompose()
void mlir::NGLSTMSequenceOp::decompose()
{
}
-void mlir::NGMatMul::decompose()
+void mlir::NGMatMulOp::decompose()
{
}
void mlir::NGLayerNormOp::decompose()
......
@@ -252,8 +252,8 @@ def NGAvgPoolOp :
}
// AvgPool for back prop
-def NGAvgPoolBackPropOp :
-    NG_OneResult_Op<"avgPoolBackProp", [NoSideEffect, OpVersion0]>,
+def NGAvgPoolBackpropOp :
+    NG_OneResult_Op<"avgPoolBackprop", [NoSideEffect, OpVersion0]>,
    Arguments<(ins I64ArrayAttr :$forwardArgShape,
                   NG_TensorType :$delta,
                   I64ArrayAttr :$windowShape,
@@ -455,11 +455,10 @@ def NGMaxPoolOp :
}
// MaxPool for back prop
-def NGMaxPoolBackPropOp :
-    NG_OneResult_Op<"maxPoolBackProp", [NoSideEffect, OpVersion0]>,
+def NGMaxPoolBackpropOp :
+    NG_OneResult_Op<"maxPoolBackprop", [NoSideEffect, OpVersion0]>,
    Arguments<(ins NG_TensorType :$argForward,
                   NG_TensorType :$delta,
-                  NG_TensorType :$resultForward,
                   I64ArrayAttr :$windowShape,
                   I64ArrayAttr :$windowMovementStrides,
                   I64ArrayAttr :$padBelow,
@@ -473,23 +472,6 @@ def NGMaxPoolBackpropOp :
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
-let builders = [
-    // Builder without resultForward
-    OpBuilder<
-        "Builder *builder, OperationState &tblgen_state, Type res, "
-        "Value *argForward, Value *delta, "
-        "ArrayAttr windowShape, ArrayAttr windowMovementStrides, "
-        "ArrayAttr padBelow, ArrayAttr padAbove", [{
-        tblgen_state.addOperands(argForward);
-        tblgen_state.addOperands(delta);
-        tblgen_state.addOperands(nullptr);
-        tblgen_state.addAttribute("windowShape", windowShape);
-        tblgen_state.addAttribute("windowMovementStrides", windowMovementStrides);
-        tblgen_state.addAttribute("padBelow", padBelow);
-        tblgen_state.addAttribute("padAbove", padAbove);
-        tblgen_state.addTypes(res);
-    }]>
-];
let extraClassDeclaration = [{
void setWindowShape(const ArrayAttr& arrayAttr) { this->setAttr("windowShape", arrayAttr); }
......
@@ -6,23 +6,31 @@
MLIR_OP(Add)
MLIR_OP(ArgMin)
MLIR_OP(ArgMax)
+MLIR_OP(AvgPool)
+MLIR_OP(AvgPoolBackprop)
MLIR_OP(Divide)
MLIR_OP(Dot)
MLIR_OP(Concat)
MLIR_OP(Convolution)
MLIR_OP(Gather)
+MLIR_OP(Gemm)
MLIR_OP(Greater)
MLIR_OP(Less)
MLIR_OP(GreaterEq)
MLIR_OP(LessEq)
MLIR_OP(Equal)
MLIR_OP(NotEqual)
+MLIR_OP(MatMul)
MLIR_OP(Maximum)
+MLIR_OP(MaxPool)
+MLIR_OP(MaxPoolBackprop)
MLIR_OP(Minimum)
MLIR_OP(Multiply)
MLIR_OP(Negative)
+MLIR_OP(Softmax)
MLIR_OP(Subtract)
MLIR_OP(Relu)
// Add new supported ops here
#undef MLIR_OP
@@ -21,28 +21,7 @@
#include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
-#include "ngraph/op/add.hpp"
-#include "ngraph/op/argmax.hpp"
-#include "ngraph/op/argmin.hpp"
-#include "ngraph/op/concat.hpp"
-#include "ngraph/op/convolution.hpp"
-#include "ngraph/op/divide.hpp"
-#include "ngraph/op/dot.hpp"
-#include "ngraph/op/equal.hpp"
-#include "ngraph/op/experimental/compiled_kernel.hpp"
-#include "ngraph/op/gather.hpp"
-#include "ngraph/op/get_output_element.hpp"
-#include "ngraph/op/greater.hpp"
-#include "ngraph/op/greater_eq.hpp"
-#include "ngraph/op/less.hpp"
-#include "ngraph/op/less_eq.hpp"
-#include "ngraph/op/maximum.hpp"
-#include "ngraph/op/minimum.hpp"
-#include "ngraph/op/multiply.hpp"
-#include "ngraph/op/negative.hpp"
-#include "ngraph/op/not_equal.hpp"
-#include "ngraph/op/relu.hpp"
-#include "ngraph/op/subtract.hpp"
+#include "ngraph/ops.hpp"
using namespace ngraph::descriptor;
using namespace ngraph::op;
@@ -498,6 +477,104 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
       std::all_of(window_dilation.begin(), window_dilation.end(), is_one);
}
// MKLDNN only supports softmax across single axis
if (TI(ngraph::op::Softmax) == TI(*node))
{
// Softmax is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto softmax = static_cast<ngraph::op::Softmax*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
return (arg0_rank == 4 || arg0_rank == 2) &&
node->get_input_element_type(0) == element::f32 && softmax->get_axes().size() == 1;
}
if (TI(ngraph::op::AvgPool) == TI(*node))
{
// AvgPool is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto avg_pool = static_cast<ngraph::op::AvgPool*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
return ((arg0_rank == 4 && avg_pool->get_window_shape().size() == 2) ||
(arg0_rank == 5 && avg_pool->get_window_shape().size() == 3)) &&
node->get_input_element_type(0) == element::f32;
}
if (TI(ngraph::op::AvgPoolBackprop) == TI(*node))
{
// AvgPoolBackprop is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto avg_pool_backprop = static_cast<ngraph::op::AvgPoolBackprop*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
return ((arg0_rank == 4 && avg_pool_backprop->get_window_shape().size() == 2) ||
(arg0_rank == 5 && avg_pool_backprop->get_window_shape().size() == 3)) &&
node->get_input_element_type(0) == element::f32;
}
if (TI(ngraph::op::MaxPoolBackprop) == TI(*node))
{
// MaxPoolBackprop is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto max_pool_backprop = static_cast<ngraph::op::MaxPoolBackprop*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
return ((arg0_rank == 4 && max_pool_backprop->get_window_shape().size() == 2) ||
(arg0_rank == 5 && max_pool_backprop->get_window_shape().size() == 3)) &&
node->get_input_element_type(0) == element::f32;
}
if (TI(ngraph::op::MaxPool) == TI(*node))
{
// MaxPool is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto max_pool = static_cast<ngraph::op::MaxPool*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
return ((arg0_rank == 4 && max_pool->get_window_shape().size() == 2) ||
(arg0_rank == 5 && max_pool->get_window_shape().size() == 3)) &&
node->get_input_element_type(0) == element::f32;
}
if (TI(ngraph::op::MatMul) == TI(*node))
{
// MatMul is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
}
if (TI(ngraph::op::Gemm) == TI(*node))
{
// Gemm is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
}
return true;
}
......
@@ -26,28 +26,7 @@
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
-#include "ngraph/op/add.hpp"
-#include "ngraph/op/argmax.hpp"
-#include "ngraph/op/argmin.hpp"
-#include "ngraph/op/concat.hpp"
-#include "ngraph/op/convolution.hpp"
-#include "ngraph/op/divide.hpp"
-#include "ngraph/op/dot.hpp"
-#include "ngraph/op/equal.hpp"
-#include "ngraph/op/experimental/compiled_kernel.hpp"
-#include "ngraph/op/gather.hpp"
-#include "ngraph/op/greater.hpp"
-#include "ngraph/op/greater_eq.hpp"
-#include "ngraph/op/less.hpp"
-#include "ngraph/op/less_eq.hpp"
-#include "ngraph/op/maximum.hpp"
-#include "ngraph/op/minimum.hpp"
-#include "ngraph/op/multiply.hpp"
-#include "ngraph/op/negative.hpp"
-#include "ngraph/op/not_equal.hpp"
-#include "ngraph/op/relu.hpp"
-#include "ngraph/op/subtract.hpp"
-#include "ngraph/op/util/index_reduction.hpp"
+#include "ngraph/ops.hpp"
#include "ngraph/type/element_type.hpp"
// Defines a new LLVM debug type for this file to be used by LLVM_DEBUG macro.
@@ -117,8 +96,9 @@ namespace
// Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
+// Use inNum when mlir OP needs less input than its corresponding ngraph OP.
template <typename Op>
-mlir::Operation* createGenericOp(const ngraph::Node* ngNode);
+mlir::Operation* createGenericOp(const ngraph::Node* ngNode, int inNum = -1);
template <typename RedOp>
mlir::Operation* createIndexReduction(const ngraph::Node* ngNode);
@@ -133,6 +113,9 @@ namespace
template <typename T>
mlir::ArrayAttr getShapeAsAttr(T ngShape);
+/// Return the real input node corresponding to the fake node
+ngraph::Node* getOriginArg(ngraph::Node* node) const;
private:
// Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiledKernel;
@@ -220,6 +203,14 @@ mlir::ArrayAttr NgDialectConversionPass::getShapeAsAttr(T ngShape)
return m_builder.getI64ArrayAttr(mlirShape);
}
ngraph::Node* NgDialectConversionPass::getOriginArg(ngraph::Node* node) const
{
auto inputMap = m_compiledKernel->get_input_map();
auto it = inputMap.find(node->shared_from_this());
NGRAPH_CHECK(it != inputMap.end(), "Parameter not in CK input map");
return m_compiledKernel->input_values().at(it->second).get_node();
}
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type NgDialectConversionPass::getMlirType(const descriptor::Tensor* tensor)
@@ -464,17 +455,157 @@ mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Convolution)
return op;
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::AvgPool)
{
mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGAvgPoolOp>(ngNode);
auto avgPoolNode = static_cast<const ngraph::op::AvgPool*>(ngNode);
auto avgPoolOp = llvm::cast<mlir::NGAvgPoolOp>(op);
mlir::BoolAttr boolAttr =
NgDialectObj.m_builder.getBoolAttr(avgPoolNode->get_include_padding_in_avg_computation());
avgPoolOp.setIncludePadding(boolAttr);
mlir::ArrayAttr attr = NgDialectObj.getShapeAsAttr(avgPoolNode->get_window_shape());
avgPoolOp.setWindowShape(attr);
attr = NgDialectObj.getShapeAsAttr(avgPoolNode->get_window_movement_strides());
avgPoolOp.setWindowMovementStrides(attr);
attr = NgDialectObj.getShapeAsAttr(avgPoolNode->get_padding_below());
avgPoolOp.setPadBelow(attr);
attr = NgDialectObj.getShapeAsAttr(avgPoolNode->get_padding_above());
avgPoolOp.setPadAbove(attr);
return op;
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::AvgPoolBackprop)
{
mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGAvgPoolBackpropOp>(ngNode);
auto avgPoolBackpropNode = static_cast<const ngraph::op::AvgPoolBackprop*>(ngNode);
auto avgPoolBackpropOp = llvm::cast<mlir::NGAvgPoolBackpropOp>(op);
mlir::BoolAttr boolAttr = NgDialectObj.m_builder.getBoolAttr(
avgPoolBackpropNode->get_include_padding_in_avg_computation());
avgPoolBackpropOp.setIncludePadding(boolAttr);
mlir::ArrayAttr attr = NgDialectObj.getShapeAsAttr(avgPoolBackpropNode->get_window_shape());
avgPoolBackpropOp.setWindowShape(attr);
attr = NgDialectObj.getShapeAsAttr(avgPoolBackpropNode->get_window_movement_strides());
avgPoolBackpropOp.setWindowMovementStrides(attr);
attr = NgDialectObj.getShapeAsAttr(avgPoolBackpropNode->get_padding_below());
avgPoolBackpropOp.setPadBelow(attr);
attr = NgDialectObj.getShapeAsAttr(avgPoolBackpropNode->get_padding_above());
avgPoolBackpropOp.setPadAbove(attr);
attr = NgDialectObj.getShapeAsAttr(avgPoolBackpropNode->get_forward_arg_shape());
avgPoolBackpropOp.setForwardArgShape(attr);
return op;
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::MaxPool)
{
mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGMaxPoolOp>(ngNode);
auto maxPoolNode = static_cast<const ngraph::op::MaxPool*>(ngNode);
auto maxPoolOp = llvm::cast<mlir::NGMaxPoolOp>(op);
mlir::ArrayAttr attr = NgDialectObj.getShapeAsAttr(maxPoolNode->get_window_shape());
maxPoolOp.setWindowShape(attr);
attr = NgDialectObj.getShapeAsAttr(maxPoolNode->get_window_movement_strides());
maxPoolOp.setWindowMovementStrides(attr);
attr = NgDialectObj.getShapeAsAttr(maxPoolNode->get_padding_below());
maxPoolOp.setPadBelow(attr);
attr = NgDialectObj.getShapeAsAttr(maxPoolNode->get_padding_above());
maxPoolOp.setPadAbove(attr);
return op;
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::MaxPoolBackprop)
{
mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGMaxPoolBackpropOp>(ngNode, 2);
auto maxPoolBackpropNode = static_cast<const ngraph::op::MaxPoolBackprop*>(ngNode);
auto maxPoolBackpropOp = llvm::cast<mlir::NGMaxPoolBackpropOp>(op);
mlir::ArrayAttr attr = NgDialectObj.getShapeAsAttr(maxPoolBackpropNode->get_window_shape());
maxPoolBackpropOp.setWindowShape(attr);
attr = NgDialectObj.getShapeAsAttr(maxPoolBackpropNode->get_window_movement_strides());
maxPoolBackpropOp.setWindowMovementStrides(attr);
attr = NgDialectObj.getShapeAsAttr(maxPoolBackpropNode->get_padding_below());
maxPoolBackpropOp.setPadBelow(attr);
attr = NgDialectObj.getShapeAsAttr(maxPoolBackpropNode->get_padding_above());
maxPoolBackpropOp.setPadAbove(attr);
return op;
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::MatMul)
{
auto matmulNode = static_cast<const ngraph::op::MatMul*>(ngNode);
auto op = NgDialectObj.createGenericOp<mlir::NGMatMulOp>(ngNode);
auto matmulOp = llvm::cast<mlir::NGMatMulOp>(op);
matmulOp.setTransposeA(NgDialectObj.m_builder.getBoolAttr(matmulNode->get_transpose_a()));
matmulOp.setTransposeB(NgDialectObj.m_builder.getBoolAttr(matmulNode->get_transpose_b()));
return op;
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Gemm)
{
auto gemmNode = static_cast<const ngraph::op::Gemm*>(ngNode);
auto op = NgDialectObj.createGenericOp<mlir::NGGemmOp>(ngNode);
auto gemmOp = llvm::cast<mlir::NGGemmOp>(op);
gemmOp.setTransA(NgDialectObj.m_builder.getBoolAttr(gemmNode->get_transA()));
gemmOp.setTransB(NgDialectObj.m_builder.getBoolAttr(gemmNode->get_transB()));
gemmOp.setAlpha(NgDialectObj.m_builder.getF32FloatAttr(gemmNode->get_alpha()));
gemmOp.setBeta(NgDialectObj.m_builder.getF32FloatAttr(gemmNode->get_beta()));
return op;
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Softmax)
{
mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGSoftMaxOp>(ngNode, 1);
auto softmaxNode = static_cast<const ngraph::op::Softmax*>(ngNode);
auto softmaxOp = llvm::cast<mlir::NGSoftMaxOp>(op);
auto originArg = NgDialectObj.getOriginArg(ngNode->input_value(1).get_node());
auto const_op = static_cast<ngraph::op::Constant*>(originArg);
AxisSet axes = const_op->get_axis_set_val();
mlir::ArrayAttr attr = NgDialectObj.getShapeAsAttr(axes);
softmaxOp.setAxes(attr);
return op;
}
template <typename Op>
-mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ngNode)
+mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ngNode, int inNum)
{
std::vector<mlir::Value*> argValues;
std::vector<mlir::Type> resTypes;
auto inputMap = m_compiledKernel->get_input_map();
std::shared_ptr<descriptor::Tensor> argTensor;
+int i = 0;
for (auto& argOutput : ngNode->input_values())
{
+if (inNum != -1 && i == inNum)
+{
+break;
+}
auto argOutputNode = argOutput.get_node();
-if (as_type<op::Parameter>(argOutputNode))
+if (is_type<op::Parameter>(argOutputNode))
{
auto it = inputMap.find(argOutputNode->shared_from_this());
NGRAPH_CHECK(it != inputMap.end(), "Parameter not in CK input map");
@@ -488,6 +619,7 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ngNode)
auto argV = getTensorValue(argTensor.get()).m_value;
argValues.push_back(argV);
+i++;
}
for (auto& output : ngNode->outputs())
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include <cstdint>
namespace ngraph
{
namespace runtime
{
namespace ngmlir
{
// OpType class is used for callbacks.
// We pass OpType to the generic callback functions,
// which call the real implementation based on OpType.
// TODO remove those not needed once all callbacks are implemented.
enum class OpType
{
ADD = 0,
AVGPOOL,
AVGPOOLBACKPROP,
BATCHNORM3ARGS,
BATCHNORM5ARGS,
BATCHNORMBACKPROP,
BOUNDEDRELU,
CONCAT,
CONVERTLAYOUT,
CONVOLUTION,
CONVOLUTIONRELU,
CONVOLUTIONADD,
CONVOLUTIONBIAS,
CONVOLUTIONBIASADD,
CONVOLUTIONBACKPROPDATA,
CONVOLUTIONBACKPROPWEIGHTS,
CONVOLUTIONBACKPROPWEIGHTSBIAS,
GELU,
GELUBACKPROP,
GEMM,
GROUPCONVOLUTION,
GROUPCONVOLUTIONBIAS,
DECONVOLUTIONBIAS,
LEAKYRELU,
LRN,
LSTM,
MATMUL,
MAXPOOL,
MAXPOOLBACKPROP,
MAXPOOLBACKPROPFORWARD,
MAXPOOLBACKPROPBACKWARD,
MAXPOOLWITHINDICES,
MAXPOOLWITHINDICESBACKPROP,
QUANTIZE,
DEQUANTIZE,
QUANTIZEDAVGPOOL,
QUANTIZEDMAXPOOL,
QUANTIZEDCONCAT,
QUANTIZEDDOTBIAS,
QUANTIZEDMATMUL,
QUANTIZEDCONVOLUTION,
QUANTIZEDCONVOLUTIONBIAS,
QUANTIZEDCONVOLUTIONBIASADD,
QUANTIZEDCONVOLUTIONBIASSIGNEDADD,
QUANTIZEDCONVOLUTIONRELU,
RELU,
RELUBACKPROP,
RNN,
SIGMOID,
SIGMOIDBACKPROP,
SLICE,
SOFTMAX
};
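
The generic callback entry points receive this enum as a plain i64 and dispatch on it. A minimal sketch of that dispatch, borrowing the __mlir_callback_1_input symbol that appears in the lit test at the end of this diff (the argument types and the mkldnn calls are illustrative only; the real definitions live in runtime/cpu/cpu_callbacks.cpp, whose diff is not shown here):

    #include <cstdint>
    #include "ngraph/check.hpp"

    // Sketch only: a single-input callback selects the mkldnn primitive to run
    // based on the OpType baked into the generated call.
    extern "C" void __mlir_callback_1_input(void* input, void* output, int64_t index, int64_t opType)
    {
        using ngraph::runtime::ngmlir::OpType;
        switch (static_cast<OpType>(opType))
        {
        case OpType::SOFTMAX: /* build and execute the mkldnn softmax primitive */ break;
        case OpType::AVGPOOL: /* build and execute the mkldnn pooling primitive */ break;
        default: NGRAPH_CHECK(false, "Unsupported op in callback");
        }
    }
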
// These structs and union are used to pass attributes to callbacks.
template <int N>
struct poolAttrs
{
bool includePaddingInAvgComputation;
int64_t windowShape[N];
int64_t windowStrides[N];
int64_t padBelow[N];
int64_t padAbove[N];
};
struct gemmAttrs
{
bool transposeA;
bool transposeB;
int64_t m;
int64_t n;
int64_t k;
int64_t lda;
int64_t ldb;
int64_t ldc;
float alpha;
float beta;
int64_t broadcastHint;
};
union opAttrs {
int intAttr;
poolAttrs<2> poolAttrs2d;
poolAttrs<3> poolAttrs3d;
gemmAttrs gemmAttrs2d;
};
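
Per the commit message, these attributes are held in a global variable on the C++ side; the generated code only carries an integer index. A hedged sketch of that scheme (the container and helper names below are invented for illustration, not the actual nGraph symbols):

    #include <vector>

    // Invented name; the actual global lives in runtime/cpu/cpu_callbacks.cpp.
    static std::vector<opAttrs> s_attrsVec;

    // At compile time each op appends its attributes, and the returned index is
    // baked into the generated __mlir_callback_* call as an i64 constant.
    static int64_t appendAttrs(const opAttrs& attrs)
    {
        s_attrsVec.push_back(attrs);
        return static_cast<int64_t>(s_attrsVec.size() - 1);
    }

    // At run time the callback recovers them, e.g.:
    //     poolAttrs<2> pAttrs = s_attrsVec[index].poolAttrs2d;
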
} // namespace ngmlir
} // namespace runtime
} // namespace ngraph
@@ -53,16 +53,18 @@ static llvm::cl::opt<std::string>
clObjectFilename("ngraph-mlir-object-filename",
                 llvm::cl::desc("Dump MLIR JITted-compiled object to file jitted_mlir.o"));
-void MLIRCPURuntime::run(void* args)
+void MLIRCPURuntime::run(const std::vector<MemRefArg>& args)
{
-run_internal(*reinterpret_cast<std::vector<void*>*>(args));
+run_internal(args);
}
-void MLIRCPURuntime::run_internal(std::vector<void*>& externalTensors)
+void MLIRCPURuntime::run_internal(const std::vector<MemRefArg>& args)
{
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer created with
// the default or user-provided optimization level.
auto llvmTransformer = mlir::makeOptimizingTransformer(
    MLIRCPUBackend::mlirOptLevel, /*sizeLevel=*/0, MLIRCPUBackend::targetMachine.get());
auto maybeEngine = mlir::ExecutionEngine::create(
@@ -70,14 +72,14 @@ void MLIRCPURuntime::run_internal(const std::vector<MemRefArg>& args)
NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
m_engine = std::move(maybeEngine.get());
-bindArguments(externalTensors);
+bindArguments(args);
execute();
cleanup();
}
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function.
-void MLIRCPURuntime::bindArguments(std::vector<void*>& externalTensors)
+void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
{
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
@@ -85,13 +87,17 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
// Set external arguments
-m_externalTensors = &externalTensors;
+m_externalTensors = &args;
// Create list with a type-erased double pointer for each invocation arguments.
// We currently use 'allocateMemrefArgs', which creates the arguments list per call ABI (see
// comment below).
// StaticMemRef is just a struct with the actual pointer to the data.
+for (auto i = 0; i < m_externalTensors->size(); i++)
+{
+    m_ranks.push_back((*m_externalTensors)[i].m_shape.size());
+}
auto expectedArguments = allocateMemrefArgs();
NGRAPH_CHECK(expectedArguments.size(), "Arguments can't be created");
m_invokeArgs = std::move(expectedArguments);
@@ -103,8 +109,14 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i)
{
    auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(m_invokeArgs[i]));
-    memRefArg->allocatedPtr = (*m_externalTensors)[i];
-    memRefArg->alignedPtr = (*m_externalTensors)[i];
+    memRefArg->allocatedPtr = (*m_externalTensors)[i].m_tensor;
+    memRefArg->alignedPtr = (*m_externalTensors)[i].m_tensor;
+    auto rank = m_ranks[i];
+    for (auto j = 0; j < rank; j++)
+    {
+        memRefArg->shapeAndStrides[j] = (*m_externalTensors)[i].m_shape[j];
+        memRefArg->shapeAndStrides[rank + j] = (*m_externalTensors)[i].m_strides[j];
+    }
}
}
@@ -128,6 +140,7 @@ void MLIRCPURuntime::execute()
void MLIRCPURuntime::cleanup()
{
// Free void double pointer arguments without freeing external tensor data.
+int i = 0;
for (auto* arg : m_invokeArgs)
{
    auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(arg));
@@ -148,7 +161,7 @@ SmallVector<void*, 8> MLIRCPURuntime::allocateMemrefArgs()
SmallVector<void*, 8> args;
for (auto i = 0; i < m_externalTensors->size(); i++)
{
-    auto descriptor = allocateMemrefDescriptor();
+    auto descriptor = allocateMemrefDescriptor(m_ranks[i]);
    StaticMemRef** arg = reinterpret_cast<StaticMemRef**>(malloc(sizeof(StaticMemRef*)));
    *arg = descriptor;
    args.push_back(arg);
@@ -156,13 +169,17 @@ SmallVector<void*, 8> MLIRCPURuntime::allocateMemrefArgs()
return args;
}
-StaticMemRef* MLIRCPURuntime::allocateMemrefDescriptor()
+StaticMemRef* MLIRCPURuntime::allocateMemrefDescriptor(size_t rank)
{
// We only use StaticMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs
-auto* descriptor = reinterpret_cast<StaticMemRef*>(malloc(sizeof(StaticMemRef)));
+// We allocate 2 * rank * sizeof(int64_t) for the last element "int64_t shapeAndStrides[]"
+// in StaticMemRef because shape and strides each needs rank * sizeof(int64_t).
+auto* descriptor =
+    reinterpret_cast<StaticMemRef*>(malloc(sizeof(StaticMemRef) + 2 * rank * sizeof(int64_t)));
NGRAPH_CHECK(descriptor != nullptr, "NULL MemRef descriptor");
descriptor->allocatedPtr = nullptr;
descriptor->alignedPtr = nullptr;
+descriptor->offset = 0;
return descriptor;
}
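
To make the tail allocation concrete, a small illustration (shape and stride values invented) of the rank-2 layout this function produces, matching the packing loop in bindArguments above:

    #include <cstdint>
    #include <cstdlib>

    // Dims occupy shapeAndStrides[0, rank); strides occupy [rank, 2 * rank).
    static StaticMemRef* exampleRank2Descriptor()
    {
        size_t rank = 2;
        auto* d = reinterpret_cast<StaticMemRef*>(
            malloc(sizeof(StaticMemRef) + 2 * rank * sizeof(int64_t)));
        d->shapeAndStrides[0] = 2; // shape[0]
        d->shapeAndStrides[1] = 3; // shape[1]
        d->shapeAndStrides[2] = 3; // stride[0] (row-major)
        d->shapeAndStrides[3] = 1; // stride[1]
        return d;                  // caller releases with free()
    }
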
@@ -37,7 +37,16 @@ namespace ngraph
{
    void* allocatedPtr;
    void* alignedPtr;
+   int64_t offset;
+   int64_t shapeAndStrides[];
};
struct UnrankedMemRef
{
int64_t rank;
StaticMemRef* memRefDescPtr;
};
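
Two small accessors (a sketch, not part of this commit) showing how a callback can read the tail-allocated dims and strides through an UnrankedMemRef:

    inline int64_t getMemRefDim(const UnrankedMemRef& m, int64_t i)
    {
        return m.memRefDescPtr->shapeAndStrides[i]; // dims occupy [0, rank)
    }

    inline int64_t getMemRefStride(const UnrankedMemRef& m, int64_t i)
    {
        return m.memRefDescPtr->shapeAndStrides[m.rank + i]; // strides occupy [rank, 2*rank)
    }
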
/// A CPU Runtime is an MLIR runtime that owns an MLIR context and a module
/// The module should be in LLVM dialect and ready to be lowered via an MLIR
/// ExecutionEngine. The runtime owns the context and must out-live any MLIR
@@ -46,12 +55,12 @@ namespace ngraph
{
public:
    /// Executes a pre-compiled subgraph
-   void run(void* args) override;
+   void run(const std::vector<MemRefArg>& args) override;
private:
-   void run_internal(std::vector<void*>& externalTensors);
+   void run_internal(const std::vector<MemRefArg>& args);
    // Bind external tensors to MLIR module entry point
-   void bindArguments(std::vector<void*>& externalTensors);
+   void bindArguments(const std::vector<MemRefArg>& args);
    // Invokes an MLIR module entry point with bound arguments
    void execute();
    // Cleans up allocated args
@@ -61,14 +70,15 @@ namespace ngraph
    llvm::SmallVector<void*, 8> allocateMemrefArgs();
    /// Helper to allocate a mem ref object. Handles static shapes only for now.
-   StaticMemRef* allocateMemrefDescriptor();
+   StaticMemRef* allocateMemrefDescriptor(size_t);
private:
    // Pointers to externally allocated memory for sub-graph's input and output tensors.
-   std::vector<void*>* m_externalTensors;
+   const std::vector<MemRefArg>* m_externalTensors;
    // Arguments for the MLIR function generated for the nGraph sub-graph.
    llvm::SmallVector<void*, 8> m_invokeArgs;
    std::unique_ptr<mlir::ExecutionEngine> m_engine;
+   std::vector<size_t> m_ranks;
};
}
}
......
@@ -33,6 +33,13 @@ namespace ngraph
{
    namespace ngmlir
    {
struct MemRefArg
{
void* m_tensor;
std::vector<size_t> m_shape;
std::vector<size_t> m_strides;
};
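
For reference, a hedged example of filling one MemRefArg per tensor before calling run(); the values are invented, and the real population happens in the CompiledKernel builder further down this diff:

    // A 2x3 row-major f32 tensor: strides are in elements, innermost last.
    static MemRefArg makeExampleArg(float* data)
    {
        MemRefArg arg;
        arg.m_tensor = data;
        arg.m_shape = {2, 3};
        arg.m_strides = {3, 1};
        return arg;
    }
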
/// Base class for an MLIR runtime. An MLIR runtime owns the MLIR Context and owns
/// the final compiled module. It supports invoking the module with specific arguments
class MLIRRuntime
@@ -43,7 +50,7 @@ namespace ngraph
    /// Overload with module op
    void set_module(mlir::ModuleOp& module) { m_module = module; }
    /// Executes a pre-compiled subgraph
-   virtual void run(void* args) = 0;
+   virtual void run(const std::vector<MemRefArg>& args) = 0;
    /// Get the MLIR module that this runtime owns
    mlir::OwningModuleRef& get_module() { return m_module; }
......
@@ -19,6 +19,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
+#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
......
@@ -43,30 +43,69 @@ namespace ngraph
// Tensors haven't been allocated yet so we have to keep a pointer to the pointer
// that will hold the future memory address.
std::vector<size_t> buffer_indices;
+std::vector<std::vector<size_t>> shape_vec;
+std::vector<std::vector<size_t>> strides_vec;
for (const TensorViewWrapper& arg : args)
{
auto buffer_index = external_function->get_buffer_index(arg.get_name());
buffer_indices.push_back(buffer_index);
// Get shape and strides
auto tensor_shape = arg.get_shape();
std::vector<size_t> shape(tensor_shape.size());
for (auto i = 0; i < tensor_shape.size(); i++)
{
shape[i] = tensor_shape[i];
}
shape_vec.push_back(shape);
auto tensor_strides = arg.get_strides();
std::vector<size_t> strides(tensor_strides.size());
for (auto i = 0; i < tensor_strides.size(); i++)
{
strides[i] = tensor_strides[i];
}
strides_vec.push_back(strides);
}
for (const TensorViewWrapper& result : out)
{
auto buffer_index = external_function->get_buffer_index(result.get_name());
buffer_indices.push_back(buffer_index);
// Get shape and strides
auto tensor_shape = result.get_shape();
std::vector<size_t> shape(tensor_shape.size());
for (auto i = 0; i < tensor_shape.size(); i++)
{
shape[i] = tensor_shape[i];
}
shape_vec.push_back(shape);
auto tensor_strides = result.get_strides();
std::vector<size_t> strides(tensor_strides.size());
for (auto i = 0; i < tensor_strides.size(); i++)
{
strides[i] = tensor_strides[i];
}
strides_vec.push_back(strides);
}
// Create functor that will be executed to compile and run this CompiledKernel.
// Note that 'double_ptr_args' must be captured by value since it's a local var.
-auto functor = [node, buffer_indices](CPURuntimeContext* ctx,
-                                      CPUExecutionContext* ectx) {
+auto functor = [node, buffer_indices, shape_vec, strides_vec](
+    CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
// MLIR requires a list of type-erased pointer to arguments. Tensors must have
// been allocated at this point so we can get rid of the extra reference.
-std::vector<void*> ptr_args;
+std::vector<MemRefArg> mem_ref_arg_vec;
+int i = 0;
for (auto& buffer_index : buffer_indices)
{
-ptr_args.push_back(ctx->buffer_data[buffer_index]);
+MemRefArg mem_ref_arg;
+mem_ref_arg.m_tensor = ctx->buffer_data[buffer_index];
+mem_ref_arg.m_shape = shape_vec[i];
+mem_ref_arg.m_strides = strides_vec[i];
+mem_ref_arg_vec.push_back(mem_ref_arg);
+i++;
}
// Compile nodes within the CompiledKernel op.
CompiledKernel* compiled_kernel =
    static_cast<CompiledKernel*>(const_cast<Node*>(node));
@@ -97,13 +136,13 @@ namespace ngraph
mlir_backend.codegen();
// Store module into runtime, and invoke.
mlir_runtime.set_module(mlir_backend.get_module());
-mlir_runtime.run(&ptr_args);
+mlir_runtime.run(mem_ref_arg_vec);
}
else
{
// We have found a cached runtime, just invoke.
MLIRCPURuntime& mlir_runtime = it->second;
-mlir_runtime.run(&ptr_args);
+mlir_runtime.run(mem_ref_arg_vec);
}
};
......
@@ -87,8 +87,10 @@
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/gelu.hpp"
+#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
+#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
@@ -1187,7 +1189,22 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
auto dex = is_direct_execution();
auto is_supported = [dex](const Node& node) {
#ifdef NGRAPH_MLIR_ENABLE
if (std::getenv("NGRAPH_MLIR") != nullptr && std::getenv("NGRAPH_MLIR_CALLBACK") != nullptr)
{
if (typeid(ngraph::op::MatMul) == typeid(node) &&
node.get_input_element_type(0) == element::f32)
{
return true;
}
if (typeid(ngraph::op::Gemm) == typeid(node) &&
node.get_input_element_type(0) == element::f32)
{
return true;
}
}
#endif
// this check averts the decomposition of LSTMCell
// we will map LSTMCell to LSTM CPU op in the later
// graph pass
......
@@ -321,6 +321,7 @@ set(MULTI_TEST_SRC
backend/logical_or.in.cpp
backend/logical_xor.in.cpp
backend/lrn.in.cpp
+backend/matmul.in.cpp
backend/max.in.cpp
backend/maximum.in.cpp
backend/min.in.cpp
......
@@ -1024,6 +1024,26 @@ NGRAPH_TEST(${BACKEND_NAME}, gemm)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, gemm_C)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{6, 4});
auto C = make_shared<op::Parameter>(element::f32, Shape{3, 4});
auto gemm_func = make_shared<op::Gemm>(A, B, C);
auto function = make_shared<Function>(NodeVector{gemm_func}, ParameterVector{A, B, C});
auto test_case = test::NgraphTestCase(function, "${BACKEND_NAME}");
// A
test_case.add_input<float>(vector<float>(18, 1));
// B
test_case.add_input<float>(vector<float>(24, 2));
// C
test_case.add_input<float>(vector<float>(12, 1));
// output
test_case.add_expected_output<float>(Shape{3, 4}, vector<float>(12, 13));
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, gemm_broadcast_input_C)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
@@ -1041,6 +1061,48 @@ NGRAPH_TEST(${BACKEND_NAME}, gemm_broadcast_input_C)
test_case.add_input<float>(vector<float>{1});
// output
test_case.add_expected_output<float>(Shape{3, 4}, vector<float>(12, 7));
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, gemm_broadcast_axes_0_input_C)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{6, 4});
auto C = make_shared<op::Parameter>(element::f32, Shape{1, 4});
auto gemm_func = make_shared<op::Gemm>(A, B, C, 0.5);
auto function = make_shared<Function>(NodeVector{gemm_func}, ParameterVector{A, B, C});
auto test_case = test::NgraphTestCase(function, "${BACKEND_NAME}");
// A
test_case.add_input<float>(vector<float>(18, 1));
// B
test_case.add_input<float>(vector<float>(24, 2));
// C
test_case.add_input<float>(vector<float>{1, 2, 3, 4});
// output
test_case.add_expected_output<float>(Shape{3, 4},
vector<float>{7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, gemm_broadcast_axes_1_input_C)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{6, 4});
auto C = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto gemm_func = make_shared<op::Gemm>(A, B, C, 0.5);
auto function = make_shared<Function>(NodeVector{gemm_func}, ParameterVector{A, B, C});
auto test_case = test::NgraphTestCase(function, "${BACKEND_NAME}");
// A
test_case.add_input<float>(vector<float>(18, 1));
// B
test_case.add_input<float>(vector<float>(24, 2));
// C
test_case.add_input<float>(vector<float>(3, 1));
// output
test_case.add_expected_output<float>(Shape{3, 4}, vector<float>(12, 7));
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, fused_clamp)
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, matmul_2x0_0x2)
{
Shape shape_a{2, 0};
Shape shape_b{0, 2};
Shape shape_r{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto B = make_shared<op::Parameter>(element::f32, shape_b);
auto f = make_shared<Function>(make_shared<op::MatMul>(A, B), ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto b = backend->create_tensor(element::f32, shape_b);
copy_data(b, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_r);
// Overwrite the initial result vector to make sure we're not just coincidentally getting the
// right value.
copy_data(result, vector<float>{2112, 2112, 2112, 2112});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f((vector<float>{0, 0, 0, 0}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_0x2_2x0)
{
Shape shape_a{0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{2, 0};
auto B = make_shared<op::Parameter>(element::f32, shape_b);
Shape shape_r{0, 0};
auto f = make_shared<Function>(make_shared<op::MatMul>(A, B), ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{});
auto b = backend->create_tensor(element::f32, shape_b);
copy_data(b, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_r);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f((vector<float>{}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3x2_2x0)
{
Shape shape_a{3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{2, 0};
auto B = make_shared<op::Parameter>(element::f32, shape_b);
Shape shape_r{3, 0};
auto f = make_shared<Function>(make_shared<op::MatMul>(A, B), ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{1, 2, 3, 4, 5, 6});
auto b = backend->create_tensor(element::f32, shape_b);
copy_data(b, vector<float>{});
auto result = backend->create_tensor(element::f32, shape_r);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f((vector<float>{}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_2x2_2x2)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
Shape shape_r{2, 2};
auto f = make_shared<Function>(make_shared<op::MatMul>(A, B), ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4});
auto b = backend->create_tensor(element::f32, shape);
copy_data(b, vector<float>{5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, shape_r);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f((vector<float>{19, 22, 43, 50}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_2x3_3x3)
{
Shape shape_in1{2, 3};
Shape shape_in2{3, 3};
Shape shape_out{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, false, false);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape_in1);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape_in2);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape_out);
copy_data(a, vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
copy_data(b, vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f(read_vector<float>(result),
vector<float>{30.f, 36.f, 42.f, 66.f, 81.f, 96.f}));
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_2x3_3x3_int64)
{
Shape shape_in1{2, 3};
Shape shape_in2{3, 3};
Shape shape_out{2, 3};
auto A = make_shared<op::Parameter>(element::i64, shape_in1);
auto B = make_shared<op::Parameter>(element::i64, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, false, false);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::i64, shape_in1);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::i64, shape_in2);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::i64, shape_out);
copy_data(a, vector<int64_t>{1, 2, 3, 4, 5, 6});
copy_data(b, vector<int64_t>{1, 2, 3, 4, 5, 6, 7, 8, 9});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(
test::all_close(read_vector<int64_t>(result), vector<int64_t>{30, 36, 42, 66, 81, 96}));
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3x2_3x3_transpose)
{
Shape shape_in1{3, 2};
Shape shape_in2{3, 3};
Shape shape_out{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, true, false);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape_in1);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape_in2);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape_out);
copy_data(a, vector<float>{1.f, 4.f, 2.f, 5.f, 3.f, 6.f});
copy_data(b, vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f(read_vector<float>(result),
vector<float>{30.f, 36.f, 42.f, 66.f, 81.f, 96.f}));
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3x2_2x3_transpose)
{
Shape shape_in1{3, 2};
Shape shape_in2{2, 3};
Shape shape_out{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, true, true);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape_in1);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape_in2);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape_out);
copy_data(a, vector<float>{1.f, 4.f, 2.f, 5.f, 3.f, 6.f});
copy_data(b, vector<float>{1.f, 3.f, 5.f, 2.f, 4.f, 6.f});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(
test::all_close_f(read_vector<float>(result), vector<float>{22.f, 28.f, 49.f, 64.f}));
}
// RUN: ngraph-opt %s -convert-ngraph-to-affine -split-input-file | FileCheck %s
// Verify that operations using callbacks are properly converted to standard call.
// -----
// Softmax Op
// CHECK-LABEL: func @simple_softmax
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> {
%0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32>
"ng.return"(%0) : (!ng.tensor<2x3xf32>) -> ()
}
// -----
// Gemm Op
// CHECK-LABEL: func @simple_gemm
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<3x6xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32>
// CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> {
%0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32>
"ng.return"(%0) : (!ng.tensor<3x4xf32>) -> ()
}
// -----
// MatMul Op
// CHECK-LABEL: func @simple_matmul
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> {
%0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32>
"ng.return"(%0) : (!ng.tensor<2x2xf32>) -> ()
}
// -----
// AvgPool Op
// CHECK-LABEL: func @simple_avgpool
// CHECK: %0 = memref_cast %arg0 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> {
%0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> ()
}
// -----
// AvgPoolBackprop Op
// CHECK-LABEL: func @simple_avgpoolbackprop
// CHECK: %0 = memref_cast %arg0 : memref<2x2x2x2xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> {
%0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> ()
}
// -----
// MaxPool Op
// CHECK-LABEL: func @simple_maxpool
// CHECK: %0 = memref_cast %arg0 : memref<64x3x7x8x10xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> {
%0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32>
"ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> ()
}
// -----
// MaxPoolBackprop Op
// CHECK-LABEL: func @simple_maxpoolbackprop
// CHECK: %0 = memref_cast %arg0 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x4x3xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> {
%0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32>
"ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> ()
}