Unverified Commit a8a9bcb5 authored by Amy Zhuang's avatar Amy Zhuang Committed by GitHub

[MLIR] Use mkldnn callback for ConvBias. (#4205)

* [MLIR] Use mkldnn callback for ConvBias.

* Add try catch.

Fix opAttrsVec.

Add rank check for Gemm and MatMul.

* Fix merge error.

* Fix a bug.

* Fix lit test.

* Modify unit test.

* Fix merge error.

* Address PR feedback.

* Address PR feedback.

* Insert callback_init function to module.

* Fix lit tests.

* Fix a bug.

* Use a set of GlobalOps for attributes.

* Address PR feedback.

* Address PR feedback.

* Fix merge error.

* Fix style error.

* Fix style error.
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 3dce6fdb
...@@ -31,6 +31,7 @@ MLIR_OP(NGAvgPoolOp , false ) ...@@ -31,6 +31,7 @@ MLIR_OP(NGAvgPoolOp , false )
MLIR_OP(NGAvgPoolBackpropOp , false ) MLIR_OP(NGAvgPoolBackpropOp , false )
MLIR_OP(NGConcatOp , true ) MLIR_OP(NGConcatOp , true )
MLIR_OP(NGConvolutionOp , false ) MLIR_OP(NGConvolutionOp , false )
MLIR_OP(NGConvBiasOp , false )
MLIR_OP(NGDivOp , true ) MLIR_OP(NGDivOp , true )
MLIR_OP(NGDotOp , false ) MLIR_OP(NGDotOp , false )
MLIR_OP(NGGatherOp , false ) MLIR_OP(NGGatherOp , false )
......
...@@ -942,8 +942,8 @@ def NGDepthToSpaceOp : ...@@ -942,8 +942,8 @@ def NGDepthToSpaceOp :
def NGConvBiasOp : def NGConvBiasOp :
NG_OneResult_Op<"convBias", [NoSideEffect, DeclareOpInterfaceMethods<FusedOp>]>, NG_OneResult_Op<"convBias", [NoSideEffect, DeclareOpInterfaceMethods<FusedOp>]>,
Arguments<(ins NG_TensorType:$images, NG_TensorType:$filters, NG_TensorType:$bias, Arguments<(ins NG_TensorType:$images, NG_TensorType:$filters, NG_TensorType:$bias,
I64ArrayAttr:$strides, I64ArrayAttr:$padBelow, I64ArrayAttr:$padAbove, I64ArrayAttr:$strides, I64ArrayAttr:$dilation, I64ArrayAttr:$padBelow,
DefaultValuedAttr<BoolAttr, "false">:$withRelu)> I64ArrayAttr:$padAbove, DefaultValuedAttr<BoolAttr, "false">:$withRelu)>
{ {
let summary = "Convolution Bias Op"; let summary = "Convolution Bias Op";
let description = "Convolution + bias forward prop for batched convolution operation."; let description = "Convolution + bias forward prop for batched convolution operation.";
...@@ -967,9 +967,10 @@ def NGConvBiasOp : ...@@ -967,9 +967,10 @@ def NGConvBiasOp :
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setStrides(const ArrayAttr& attr) { this->setAttr("strides", attr); } void setStrides(const ArrayAttr& attr) { this->setAttr("strides", attr); }
void setDilation(const ArrayAttr& attr) { this->setAttr("dilation", attr); }
void setPadAbove(const ArrayAttr& attr) { this->setAttr("padAbove", attr); } void setPadAbove(const ArrayAttr& attr) { this->setAttr("padAbove", attr); }
void setPadBelow(const ArrayAttr& attr) { this->setAttr("padBelow", attr); } void setPadBelow(const ArrayAttr& attr) { this->setAttr("padBelow", attr); }
void setWithRelu(const Attribute& attr) {this->setAttr("withRelu", attr); } void setWithRelu(const Attribute& attr) { this->setAttr("withRelu", attr); }
}]; }];
} }
......
...@@ -12,6 +12,7 @@ MLIR_OP(Divide) ...@@ -12,6 +12,7 @@ MLIR_OP(Divide)
MLIR_OP(Dot) MLIR_OP(Dot)
MLIR_OP(Concat) MLIR_OP(Concat)
MLIR_OP(Convolution) MLIR_OP(Convolution)
MLIR_OP(ConvolutionBias)
MLIR_OP(Gather) MLIR_OP(Gather)
MLIR_OP(Gemm) MLIR_OP(Gemm)
MLIR_OP(Greater) MLIR_OP(Greater)
......
...@@ -420,6 +420,47 @@ void MLIRSubgraphExtractionPass::sanity_check(std::shared_ptr<Function> func, No ...@@ -420,6 +420,47 @@ void MLIRSubgraphExtractionPass::sanity_check(std::shared_ptr<Function> func, No
} }
} }
// Check if convolution related nodes such as Convolution, ConvolutionBias,
// ConvolutionRelu, ... can use callback.
template <typename T>
static bool can_use_mkldnn_conv_callback(ngraph::Node* node)
{
auto convolution = static_cast<const T*>(node);
auto arg0_rank = node->get_input_shape(0).size();
auto dilation = convolution->get_data_dilation_strides();
if (std::any_of(dilation.begin(), dilation.end(), [](size_t s) { return s != 1; }))
{
return false;
}
// MKLDNN doesnt support negative padding
auto pad_above = convolution->get_padding_above();
if (std::any_of(pad_above.begin(), pad_above.end(), [](size_t s) { return s < 0; }))
{
return false;
}
auto pad_below = convolution->get_padding_below();
if (std::any_of(pad_below.begin(), pad_below.end(), [](size_t s) { return s < 0; }))
{
return false;
}
if (arg0_rank != 3 && arg0_rank != 4 && arg0_rank != 5)
{
return false;
}
// Only support f32 for now
if (node->get_input_element_type(0) != ngraph::element::f32 ||
node->get_input_element_type(1) != ngraph::element::f32 ||
node->get_output_element_type(0) != ngraph::element::f32)
{
return false;
}
return true;
}
bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node) bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
{ {
if (is_type<Parameter>(node) || is_type<Result>(node)) if (is_type<Parameter>(node) || is_type<Result>(node))
...@@ -474,6 +515,16 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node ...@@ -474,6 +515,16 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
std::all_of(window_dilation.begin(), window_dilation.end(), is_one); std::all_of(window_dilation.begin(), window_dilation.end(), is_one);
} }
if (is_type<ngraph::op::ConvolutionBias>(node))
{
// ConvBias is only supported through callback
if (!getenv_bool("NGRAPH_MLIR_CALLBACK"))
{
return false;
}
return can_use_mkldnn_conv_callback<ngraph::op::ConvolutionBias>(node.get());
}
// MKLDNN only supports softmax across single axis // MKLDNN only supports softmax across single axis
if (auto softmax = as_type_ptr<ngraph::op::Softmax>(node)) if (auto softmax = as_type_ptr<ngraph::op::Softmax>(node))
{ {
...@@ -552,7 +603,8 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node ...@@ -552,7 +603,8 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
if (is_type<ngraph::op::MatMul>(node)) if (is_type<ngraph::op::MatMul>(node))
{ {
// MatMul is only supported through callback // MatMul is only supported through callback
if (!getenv_bool("NGRAPH_MLIR_CALLBACK")) if (!getenv_bool("NGRAPH_MLIR_CALLBACK") || node->get_input_shape(0).size() != 2 ||
node->get_input_shape(1).size() != 2)
{ {
return false; return false;
} }
...@@ -561,7 +613,8 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node ...@@ -561,7 +613,8 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
if (is_type<ngraph::op::Gemm>(node)) if (is_type<ngraph::op::Gemm>(node))
{ {
// Gemm is only supported through callback // Gemm is only supported through callback
if (!getenv_bool("NGRAPH_MLIR_CALLBACK")) if (!getenv_bool("NGRAPH_MLIR_CALLBACK") || node->get_input_shape(0).size() != 2 ||
node->get_input_shape(1).size() != 2)
{ {
return false; return false;
} }
......
...@@ -478,6 +478,21 @@ mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::GroupConvo ...@@ -478,6 +478,21 @@ mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::GroupConvo
return op; return op;
} }
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::ConvolutionBias)
{
mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGConvBiasOp>(ngNode);
auto convNode = static_cast<const ngraph::op::ConvolutionBias*>(ngNode);
auto convOp = llvm::cast<mlir::NGConvBiasOp>(op);
convOp.setStrides(NgDialectObj.getShapeAsAttr(convNode->get_window_movement_strides()));
convOp.setDilation(NgDialectObj.getShapeAsAttr(convNode->get_window_dilation_strides()));
convOp.setPadBelow(NgDialectObj.getShapeAsAttr(convNode->get_padding_below()));
convOp.setPadAbove(NgDialectObj.getShapeAsAttr(convNode->get_padding_above()));
convOp.setWithRelu(NgDialectObj.m_builder.getBoolAttr(convNode->with_relu()));
return op;
}
template <> template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::AvgPool) mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::AvgPool)
{ {
......
...@@ -94,6 +94,16 @@ namespace ngraph ...@@ -94,6 +94,16 @@ namespace ngraph
}; };
// These structs and union are used to pass attributes to callbacks. // These structs and union are used to pass attributes to callbacks.
template <int N>
struct convAttrs
{
bool withRelu;
int64_t windowStrides[N];
int64_t windowDilation[N];
int64_t padBelow[N];
int64_t padAbove[N];
};
template <int N> template <int N>
struct poolAttrs struct poolAttrs
{ {
...@@ -119,8 +129,22 @@ namespace ngraph ...@@ -119,8 +129,22 @@ namespace ngraph
BroadcastType broadcastHint; BroadcastType broadcastHint;
}; };
enum class AttrsType
{
INT = 0,
CONV1D,
CONV2D,
CONV3D,
POOL2D,
POOL3D,
GEMM
};
union opAttrs { union opAttrs {
int intAttr; int64_t intAttr;
convAttrs<1> convAttrs1d;
convAttrs<2> convAttrs2d;
convAttrs<3> convAttrs3d;
poolAttrs<2> poolAttrs2d; poolAttrs<2> poolAttrs2d;
poolAttrs<3> poolAttrs3d; poolAttrs<3> poolAttrs3d;
gemmAttrs gemmAttrs2d; gemmAttrs gemmAttrs2d;
......
...@@ -60,27 +60,29 @@ llvm::cl::opt<bool> clEnableBarePtrMemRefLowering( ...@@ -60,27 +60,29 @@ llvm::cl::opt<bool> clEnableBarePtrMemRefLowering(
llvm::cl::init(false), llvm::cl::init(false),
llvm::cl::desc("Enable the lowering of MemRefs to LLVM bare pointers")); llvm::cl::desc("Enable the lowering of MemRefs to LLVM bare pointers"));
void MLIRCPURuntime::run(const std::vector<MemRefArg>& args) void MLIRCPURuntime::run(const std::vector<MemRefArg>& args, bool firstIteration)
{ {
// run_internal(*reinterpret_cast<std::vector<void*>*>(args), shapeVec, stridesVec); run_internal(args, firstIteration);
run_internal(args);
} }
void MLIRCPURuntime::run_internal(const std::vector<MemRefArg>& args) void MLIRCPURuntime::run_internal(const std::vector<MemRefArg>& args, bool firstIteration)
{ {
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we // Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer created with // don't run MLIR passes that were already run. We also pass a default transformer created with
// the default or user-provided optimization level. // the default or user-provided optimization level.
if (!m_engine)
{
auto llvmTransformer = mlir::makeOptimizingTransformer( auto llvmTransformer = mlir::makeOptimizingTransformer(
MLIRCPUBackend::mlirOptLevel, /*sizeLevel=*/0, MLIRCPUBackend::targetMachine.get()); MLIRCPUBackend::mlirOptLevel, /*sizeLevel=*/0, MLIRCPUBackend::targetMachine.get());
auto maybeEngine = mlir::ExecutionEngine::create( auto maybeEngine = mlir::ExecutionEngine::create(
m_module.get(), llvmTransformer, MLIRCPUBackend::mlirOptLevel); m_module.get(), llvmTransformer, MLIRCPUBackend::mlirOptLevel);
NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine"); NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
m_engine = std::move(maybeEngine.get()); m_engine = std::move(maybeEngine.get());
}
bindArguments(args); bindArguments(args);
execute(); execute(firstIteration);
cleanup(); cleanup();
} }
...@@ -90,7 +92,8 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args) ...@@ -90,7 +92,8 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
{ {
NGRAPH_CHECK(m_module, "MLIR module is not ready."); NGRAPH_CHECK(m_module, "MLIR module is not ready.");
auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("_mlir_ciface_main"); auto name = clEnableBarePtrMemRefLowering ? "main" : "_mlir_ciface_main";
auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>(name);
NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found"); NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
// Set external arguments // Set external arguments
...@@ -138,21 +141,46 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args) ...@@ -138,21 +141,46 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
} }
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code. // Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
void MLIRCPURuntime::execute() void MLIRCPURuntime::execute(bool firstIteration)
{ {
// Invoke the JIT-compiled function with the arguments. Note that, for API // Invoke the JIT-compiled function with the arguments. Note that, for API
// uniformity reasons, it takes a list of type-erased pointers to arguments. // uniformity reasons, it takes a list of type-erased pointers to arguments.
// Please, note that 'invoke' method is overloaded with a parameter pack version. // Please, note that 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked. // Make sure the MutableArrayRef version is invoked.
if (!clEnableBarePtrMemRefLowering)
{
if (firstIteration)
{
auto invocationResult = m_engine->invoke("_mlir_ciface_callback_init");
if (clDumpObjectFile)
{
m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
: clObjectFilename.getValue());
}
NGRAPH_CHECK(!invocationResult,
"JIT invocation of '_mlir_ciface_callback_init' failed\n");
}
auto invocationResult = auto invocationResult =
m_engine->invoke("_mlir_ciface_main", llvm::MutableArrayRef<void*>(m_invokeArgs)); m_engine->invoke("_mlir_ciface_main", llvm::MutableArrayRef<void*>(m_invokeArgs));
if (clDumpObjectFile) if (clDumpObjectFile)
{ {
m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o" m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
: clObjectFilename.getValue()); : clObjectFilename.getValue());
} }
NGRAPH_CHECK(!invocationResult, "JIT invocation of '_mlir_ciface_main' failed\n"); NGRAPH_CHECK(!invocationResult, "JIT invocation of '_mlir_ciface_main' failed\n");
}
else
{
auto invocationResult =
m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invokeArgs));
if (clDumpObjectFile)
{
m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
: clObjectFilename.getValue());
}
NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n");
}
} }
void MLIRCPURuntime::cleanup() void MLIRCPURuntime::cleanup()
......
...@@ -55,14 +55,14 @@ namespace ngraph ...@@ -55,14 +55,14 @@ namespace ngraph
{ {
public: public:
/// Executes a pre-compiled subgraph /// Executes a pre-compiled subgraph
void run(const std::vector<MemRefArg>& args) override; void run(const std::vector<MemRefArg>& args, bool firstIteration) override;
private: private:
void run_internal(const std::vector<MemRefArg>& args); void run_internal(const std::vector<MemRefArg>& args, bool firstIteration);
// Bind external tensors to MLIR module entry point // Bind external tensors to MLIR module entry point
void bindArguments(const std::vector<MemRefArg>& args); void bindArguments(const std::vector<MemRefArg>& args);
// Invokes an MLIR module entry point with bound arguments // Invokes an MLIR module entry point with bound arguments
void execute(); void execute(bool firstIteration);
// Cleans up allocated args // Cleans up allocated args
void cleanup(); void cleanup();
......
...@@ -50,7 +50,7 @@ namespace ngraph ...@@ -50,7 +50,7 @@ namespace ngraph
/// Overload with module op /// Overload with module op
void set_module(mlir::ModuleOp& module) { m_module = module; } void set_module(mlir::ModuleOp& module) { m_module = module; }
/// Executes a pre-compiled subgraph /// Executes a pre-compiled subgraph
virtual void run(const std::vector<MemRefArg>& args) = 0; virtual void run(const std::vector<MemRefArg>& args, bool firstIteration) = 0;
/// Get the MLIR module that this runtime owns /// Get the MLIR module that this runtime owns
mlir::OwningModuleRef& get_module() { return m_module; } mlir::OwningModuleRef& get_module() { return m_module; }
......
...@@ -136,13 +136,13 @@ namespace ngraph ...@@ -136,13 +136,13 @@ namespace ngraph
mlir_backend.codegen(); mlir_backend.codegen();
// Store module into runtime, and invoke. // Store module into runtime, and invoke.
mlir_runtime.set_module(mlir_backend.get_module()); mlir_runtime.set_module(mlir_backend.get_module());
mlir_runtime.run(mem_ref_arg_vec); mlir_runtime.run(mem_ref_arg_vec, true /*firstIteration*/);
} }
else else
{ {
// We have found a cached runtime, just invoke. // We have found a cached runtime, just invoke.
MLIRCPURuntime& mlir_runtime = it->second; MLIRCPURuntime& mlir_runtime = it->second;
mlir_runtime.run(mem_ref_arg_vec); mlir_runtime.run(mem_ref_arg_vec, false /*firstIteration*/);
} }
}; };
......
...@@ -1191,13 +1191,15 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes( ...@@ -1191,13 +1191,15 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
if (getenv_bool("NGRAPH_MLIR") && getenv_bool("NGRAPH_MLIR_CALLBACK")) if (getenv_bool("NGRAPH_MLIR") && getenv_bool("NGRAPH_MLIR_CALLBACK"))
{ {
if (typeid(ngraph::op::MatMul) == typeid(node) && if (typeid(ngraph::op::MatMul) == typeid(node) &&
node.get_input_element_type(0) == element::f32) node.get_input_element_type(0) == element::f32 &&
node.get_input_shape(0).size() == 2 && node.get_input_shape(1).size() == 2)
{ {
return true; return true;
} }
if (typeid(ngraph::op::Gemm) == typeid(node) && if (typeid(ngraph::op::Gemm) == typeid(node) &&
node.get_input_element_type(0) == element::f32) node.get_input_element_type(0) == element::f32 &&
node.get_input_shape(0).size() == 2 && node.get_input_shape(1).size() == 2)
{ {
return true; return true;
} }
......
...@@ -992,8 +992,12 @@ TEST(cpu_fusion, conv_horizontal_fusion) ...@@ -992,8 +992,12 @@ TEST(cpu_fusion, conv_horizontal_fusion)
auto cpu_results = execute(cpu_f, args, "CPU"); auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0))); EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
size_t cpu_ck = count_ops_of_type<op::CompiledKernel>(cpu_f);
if (!cpu_ck)
{
size_t cpu_cb = count_ops_of_type<op::ConvolutionBias>(cpu_f); size_t cpu_cb = count_ops_of_type<op::ConvolutionBias>(cpu_f);
ASSERT_EQ(cpu_cb, 1); ASSERT_EQ(cpu_cb, 1);
}
} }
// ConvolutionBiasAdd relies on an in-place fused MKLDNN kernel. // ConvolutionBiasAdd relies on an in-place fused MKLDNN kernel.
......
...@@ -1052,7 +1052,12 @@ TEST(cpu_test, thread_safe_calls_convolution_2d_2items) ...@@ -1052,7 +1052,12 @@ TEST(cpu_test, thread_safe_calls_convolution_2d_2items)
unset_environment("NGRAPH_CPU_CONCURRENCY"); unset_environment("NGRAPH_CPU_CONCURRENCY");
} }
TEST(cpu_test, constant_convertlayout) // This test checks if a ConverLayout node is inserted before the ConvolutionBias node.
// Since MLIR supports ConvolutionBias through callback, the data layout conversion is done in
// callback.
// There is no ConvertLayout node when MLIR and MLIR CALLBACK are enabled.
// Thus this test is disabled with MLIR enabled.
TEST(cpu_test, MLIR_DISABLE_TEST(constant_convertlayout))
{ {
Shape data_shape{1, 64, 56, 56}; Shape data_shape{1, 64, 56, 56};
auto data = make_shared<op::Parameter>(element::f32, data_shape); auto data = make_shared<op::Parameter>(element::f32, data_shape);
......
...@@ -250,8 +250,8 @@ func @depthToSpace(%arg0: !ng.tensor<1x8x2x2xf32>) -> !ng.tensor<1x2x4x4xf32> ...@@ -250,8 +250,8 @@ func @depthToSpace(%arg0: !ng.tensor<1x8x2x2xf32>) -> !ng.tensor<1x2x4x4xf32>
//CHECK-LABEL: func @convBias //CHECK-LABEL: func @convBias
func @convBias(%arg0: !ng.tensor<1x3x2xf32>, %arg1: !ng.tensor<2x3x1xf32>, %arg2: !ng.tensor<2xf32>) -> (!ng.tensor<1x2x2xf32>) func @convBias(%arg0: !ng.tensor<1x3x2xf32>, %arg1: !ng.tensor<2x3x1xf32>, %arg2: !ng.tensor<2xf32>) -> (!ng.tensor<1x2x2xf32>)
{ {
//CHECK: %{{.*}} = "ng.convBias"(%{{.*}}, %{{.*}}, %{{.*}}) {padAbove = [0], padBelow = [0], strides = [1]} : (!ng.tensor<1x3x2xf32>, !ng.tensor<2x3x1xf32>, !ng.tensor<2xf32>) -> !ng.tensor<1x2x2xf32> //CHECK: %{{.*}} = "ng.convBias"(%{{.*}}, %{{.*}}, %{{.*}}) {dilation = [1], padAbove = [0], padBelow = [0], strides = [1]} : (!ng.tensor<1x3x2xf32>, !ng.tensor<2x3x1xf32>, !ng.tensor<2xf32>) -> !ng.tensor<1x2x2xf32>
%0 = "ng.convBias"(%arg0, %arg1, %arg2) {padAbove=[0], padBelow=[0], strides=[1]} %0 = "ng.convBias"(%arg0, %arg1, %arg2) {dilation=[1], padAbove=[0], padBelow=[0], strides=[1]}
: (!ng.tensor<1x3x2xf32>, !ng.tensor<2x3x1xf32>, !ng.tensor<2xf32>) -> !ng.tensor<1x2x2xf32> : (!ng.tensor<1x3x2xf32>, !ng.tensor<2x3x1xf32>, !ng.tensor<2xf32>) -> !ng.tensor<1x2x2xf32>
"ng.return"(%0) : (!ng.tensor<1x2x2xf32>) -> () "ng.return"(%0) : (!ng.tensor<1x2x2xf32>) -> ()
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment