Unverified Commit a9268c8f authored by Pruthvi, committed by GitHub

Merge branch 'master' into pruthvi/memory_allocator

parents 825d5df0 01302f82
// ******************************************************************************
// Copyright 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ******************************************************************************
try{ if(LABEL.trim() == "") {throw new Exception();} }catch(Exception e){LABEL="onnx && ci"}; echo "${LABEL}"
NGRAPH_REPOSITORY = "https://github.com/NervanaSystems/ngraph.git"
NGRAPH_COMMIT_HASH = "${ghprbActualCommit}" // particular nGraph PR commit hash
ONNX_REPOSITORY = "https://github.com/NervanaSystems/onnxruntime.git"
ONNX_RUNTIME_BRANCH = "release"
def main(){
    timeout(activity: true, time: 15) {
        try{
            stage("CloneRepos"){
                CloneRepos()
            }
            stage("Apply Patch"){
                ApplyPatch()
            }
            stage("Onnx Models"){
                BuildAndTest()
            }
        }
        catch(e) {
            // Set result to ABORTED if exception contains exit code of a process interrupted by SIGTERM
            if ("$e".contains("143")) {
                currentBuild.result = "ABORTED"
            } else {
                currentBuild.result = "FAILURE"
            }
        }
        stage("Clean"){
            Clean()
        }
    }
}
def CloneRepos() {
    dir("ngraph"){
        checkout([
            $class: 'GitSCM',
            branches: [[name: "${NGRAPH_COMMIT_HASH}"]],
            doGenerateSubmoduleConfigurations: false,
            extensions: [[
                $class: 'SubmoduleOption',
                disableSubmodules: false,
                parentCredentials: true,
                recursiveSubmodules: true,
                reference: '',
                trackingSubmodules: false,
                timeout: 15
            ]],
            submoduleCfg: [],
            userRemoteConfigs: [[
                refspec: '+refs/pull/*:refs/remotes/origin/pr/*',
                url: "${NGRAPH_REPOSITORY}"
            ]]
        ])
    }
    dir("onnxruntime") {
        checkout([
            $class: 'GitSCM',
            branches: [[name: "${ONNX_RUNTIME_BRANCH}"]],
            doGenerateSubmoduleConfigurations: false,
            extensions: [[
                $class: 'SubmoduleOption',
                disableSubmodules: false,
                parentCredentials: true,
                recursiveSubmodules: true,
                reference: '',
                trackingSubmodules: false,
                timeout: 15
            ]],
            submoduleCfg: [],
            userRemoteConfigs: [[
                url: "${ONNX_REPOSITORY}"
            ]]
        ])
    }
}
def ApplyPatch(){
    dir("onnxruntime"){
        echo "Update cmake/external/ngraph.cmake with ${NGRAPH_COMMIT_HASH}"
        sh """
            sed -i 's/set(ngraph_TAG ".*")/set(ngraph_TAG "${NGRAPH_COMMIT_HASH}")/g' cmake/external/ngraph.cmake
            grep -q "${NGRAPH_COMMIT_HASH}" cmake/external/ngraph.cmake
        """
        echo "Add proxy to tools/ci_build/github/linux/docker/Dockerfile.ubuntu"
        sh """
            sed -i 's|{HTTP_PROXY}|${env.http_proxy}|g' ../ngraph/.ci/onnx/onnxruntime/proxy.patch
            sed -i 's|{SOCKS_PROXY}|${env.socks_proxy}|g' ../ngraph/.ci/onnx/onnxruntime/proxy.patch
            grep -q "${env.http_proxy}" ../ngraph/.ci/onnx/onnxruntime/proxy.patch
            git apply ../ngraph/.ci/onnx/onnxruntime/proxy.patch
        """
    }
}
def BuildAndTest(){
    dir("onnxruntime"){
        sh "mkdir -p `pwd`/build/models && chmod 777 `pwd`/build/models"
        sh """
#!/bin/bash
./tools/ci_build/github/linux/run_dockerbuild.sh \
    -o ubuntu16.04 \
    -d ngraph \
    -r `pwd`/build -x '--use_ngraph --use_full_protobuf --test_data_url https://onnxruntimetestdata.blob.core.windows.net/models/20190327.zip --test_data_checksum 45166d81c021c8aae212b53c92101792'
        """
    }
}
def Clean(){
    deleteDir()
}

node(LABEL) {
    main()
}
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu
index bdff95e1..cd9c0008 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu
+++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu
@@ -3,6 +3,18 @@ FROM ubuntu:${OS_VERSION}
ARG PYTHON_VERSION=3.5
+ENV http_proxy={HTTP_PROXY}
+ENV socks_proxy={SOCKS_PROXY}
+ENV https_proxy={HTTP_PROXY}
+ENV ftp_proxy={HTTP_PROXY}
+ENV rsync_proxy={HTTP_PROXY}
+ENV no_proxy=intel.com,.intel.com,localhost
+ENV HTTP_PROXY={HTTP_PROXY}
+ENV HTTPS_PROXY={HTTP_PROXY}
+ENV FTP_PROXY={HTTP_PROXY}
+ENV SOCKS_PROXY={SOCKS_PROXY}
+ENV NO_PROXY=intel.com,.intel.com,localhost
+
ADD scripts /tmp/scripts
RUN /tmp/scripts/install_ubuntu.sh -p ${PYTHON_VERSION} && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts
Contributor Guidelines
======================
-https://ngraph.nervanasys.com/docs/latest/project/code-contributor-README.html
+The latest version of this file can be found at:
+https://ngraph.nervanasys.com/docs/latest/project/contribution-guide.html
License
......
@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions
-set(MLIR_LLVM_COMMIT_ID bb2b527)
-set(MLIR_COMMIT_ID 49f7efc)
+set(MLIR_LLVM_COMMIT_ID c0cad98)
+set(MLIR_COMMIT_ID 82d5084)
set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project)
set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects)
set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir)
......
@@ -56,6 +56,7 @@ if (NGRAPH_MLIR_ENABLE)
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
+MLIRStandardToLLVM
MLIRParser
MLIRPass
MLIRTargetLLVMIR
......
@@ -34,11 +34,12 @@
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h>
+#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
+#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/LLVMIR/LLVMDialect.h>
-#include <mlir/LLVMIR/Transforms.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Target/LLVMIR.h>
#include <mlir/Transforms/DialectConversion.h>
@@ -50,6 +51,7 @@
using llvm::SmallVector;
using llvm::StringRef;
using llvm::make_unique;
+
using namespace ngraph::runtime::ngmlir;
#define COMPILE_OP_DECL(op_name) \
@@ -75,7 +77,7 @@ void MLIRCompiler::init_mlir()
if (!initialized)
{
-mlir::registerDialect<mlir::NGDialect>();
+mlir::registerDialect<mlir::NGraphOpsDialect>();
// Register any LLVM command line options
llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
initialized = true;
@@ -133,7 +135,7 @@ void MLIRCompiler::build_ng_dialect_module()
}
// create builder
-m_builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
+m_builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
build_ng_dialect();
m_module->getFunctions().push_back(function.release());
if (failed(m_module->verify()))
@@ -359,10 +361,14 @@ void MLIRCompiler::execute()
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
// Lower Standard dialect to LLVM dialect.
-auto converter = mlir::createStdToLLVMConverter();
-auto r = converter->convert(m_module.get());
-(void)r;
-NGRAPH_CHECK(succeeded(r), "second conversion failed");
+mlir::LLVMTypeConverter llvm_converter(&m_context);
+OwningRewritePatternList patterns;
+mlir::populateStdToLLVMConversionPatterns(llvm_converter, patterns);
+
+mlir::ConversionTarget target(m_context);
+target.addLegalDialect<mlir::LLVM::LLVMDialect>();
+auto result = applyConversionPatterns(*m_module, target, llvm_converter, std::move(patterns));
+NGRAPH_CHECK(succeeded(result), "Standard to LLVM dialect conversion failed");
dump_mlir_module("LLVM-IR Dialect Dump:");
......
@@ -132,7 +132,7 @@ namespace ngraph
mlir::MLIRContext m_context;
std::unique_ptr<mlir::Module> m_module;
-std::unique_ptr<mlir::FuncBuilder> m_builder;
+std::unique_ptr<mlir::OpBuilder> m_builder;
std::unique_ptr<mlir::ExecutionEngine> m_engine;
using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>;
......
@@ -21,8 +21,8 @@
using namespace mlir;
-NGDialect::NGDialect(mlir::MLIRContext* ctx)
-    : mlir::Dialect("ng", ctx)
+NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx)
{
addTypes<NGTensorType>();
addTypes<NGIntegerType>();
@@ -34,7 +34,7 @@ NGDialect::NGDialect(mlir::MLIRContext* ctx)
>();
}
-void NGDialect::printType(mlir::Type type, raw_ostream& os) const
+void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const
{
switch (type.getKind())
{
......
@@ -25,15 +25,17 @@
#include "ngraph/check.hpp"
namespace mlir
{
-    class NGDialect : public mlir::Dialect
+    class NGraphOpsDialect : public mlir::Dialect
    {
    public:
-        explicit NGDialect(mlir::MLIRContext* ctx);
+        explicit NGraphOpsDialect(mlir::MLIRContext* ctx);
        mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
        {
            NGRAPH_CHECK(false, "Unsupported type parsing.");
            return mlir::Type();
        }
        void printType(mlir::Type type, llvm::raw_ostream& os) const override;
+        static StringRef getDialectNamespace() { return "ng"; }
    };
}
@@ -41,31 +41,34 @@ namespace
class DialectLoweringPass;

+/// Base class for nGraph operation conversions to affine/standard dialect. Provides
+/// conversion patterns with an access to the DialectLoweringPass which holds the state of the
+/// conversion.
+class NGraphOpLowering : public ConversionPattern
+{
+public:
+    NGraphOpLowering(StringRef rootOpName, MLIRContext* context, DialectLoweringPass& pass)
+        : ConversionPattern(rootOpName, /*benefit=*/1, context)
+        , m_pass(pass){};
+
+protected:
+    // Back-reference to the lowering pass which contains the lowering state, including the
+    // nGraph type converter.
+    DialectLoweringPass& m_pass;
+};

#include "op_lowerers.inc"

-/// Use Dialect Converson Framework
-class DialectLowerer : public DialectConversion
+/// Conversion from types in the nGraph dialect to the Standard dialect.
+class NGraphTypeConverter : public TypeConverter
{
public:
-    DialectLowerer(DialectLoweringPass& pass)
-        : DialectConversion()
-        , m_pass(pass)
+    NGraphTypeConverter()
+        : TypeConverter()
    {
    }

    Type convertType(Type t) override;
-
-protected:
-    // Initialize the list of converters.
-    void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override
-    {
-        RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
-            patterns, mlirContext, m_pass);
-    }
-
-private:
-    DialectLoweringPass& m_pass;
-    llvm::BumpPtrAllocator allocator;
};

/// Dialect Lowering Pass to affine ops
@@ -73,14 +76,17 @@ namespace
{
public:
    DialectLoweringPass(ngmlir::MLIRCompiler& compiler)
-        : m_dialectLowerer(*this)
-        , m_compiler(compiler)
+        : m_compiler(compiler)
    {
    }

    void runOnModule() override;
    SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);

private:
+    /// Collect a set of patterns to convert from the nGraph dialect to Affine dialect.
+    void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
+
    mlir::Function* getCallDecl(StringRef name,
                                ArrayRef<Type> args,
                                ArrayRef<Type> output,
@@ -90,7 +96,7 @@ namespace
    Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);

private:
-    DialectLowerer m_dialectLowerer;
+    NGraphTypeConverter m_typeConverter;
    // Value holding mem manager passed pointer
    SmallVector<Value*, 4> m_memMgrDefs;
@@ -101,21 +107,39 @@ namespace
void DialectLoweringPass::runOnModule()
{
+    // Create type converter and initialize conversion patterns.
+    NGraphTypeConverter converter;
+    OwningRewritePatternList patterns;
+    populateNGraphToAffineConversionPatterns(patterns);
+
+    // Create target that defines legal ops for nGraph dialect to be lowered to.
+    ConversionTarget target(getContext());
+    // TODO: Remove NGFakeInputOp. We need to set NGFakeInputOp as legal op because we generate
+    // it as part of the lowering to affine/standard.
+    target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
+    target.addLegalOp<NGFakeInputOp>();
+
    // capture output values by looking for the Return and grabbing the values
    // the order of the returned values matches the order of the lowered func signature for
    // results. This is used to find the arg_id that a defined value maps to if it is an output
    findOutputValues();
-    if (failed(m_dialectLowerer.convert(&getModule())))
+    if (failed(applyConversionPatterns(getModule(), target, converter, std::move(patterns))))
    {
-        getModule().getContext()->emitError(mlir::UnknownLoc::get(getModule().getContext()),
-                                            "Error lowering dialect\n");
+        emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering nGraph dialect\n");
        signalPassFailure();
    }
    processFakeInstrs();
}

+void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
+    OwningRewritePatternList& patterns)
+{
+    RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
+        patterns, &getContext(), *this);
+}
+
void DialectLoweringPass::findOutputValues()
{
    // get original function
@@ -138,6 +162,9 @@ namespace
    outputCount = ret.getNumOperands();
});
// will be populated with lowered output values later
+// TODO: This resize is making debugging obscure. When the container is not populated due
+// to a bug, null pointers are used by the consumer leading to a crash more difficult to
+// root-cause. We should try to change the current approach or introduce verification code.
m_loweredOutputValues.resize(outputCount, nullptr);
}
@@ -146,10 +173,11 @@ namespace
{
// it would be nice to insert one fake def at the start of the new func
// however, due to how DialectConversion framework works, new func is only
-// materialized after conversion is done (rewriter->getFunction, or even rewriter->getInsertionBlock()->getFunction()
-// will give you the original func). This makes it very convoluted to insert instructions at entry block.
+// materialized after conversion is done (rewriter->getFunction, or even
+// rewriter->getInsertionBlock()->getFunction() will give you the original func). This
+// makes it very convoluted to insert instructions at entry block.
auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(),
-                                          IndexType::get(getModule().getContext()));
+                                          IndexType::get(&getContext()));
// will be fixed later to read passed arg instead.
m_memMgrDefs.push_back(op.getResult());
return op.getResult();
@@ -167,8 +195,7 @@ namespace
unsigned argId = (int)attr.getInt();
auto fakeOp = rewriter.create<NGFakeInputOp>(
    op->getLoc(),
-    m_dialectLowerer.convertType(
-        origResult->getType()) /* convert to lowered type */
+    m_typeConverter.convertType(origResult->getType()) /* convert to lowered type */
    );
// Fake instrution is short-lived. Verify here.
fakeOp.verify();
@@ -181,7 +208,7 @@ namespace
auto tensorType = origResult->getType().cast<NGTensorType>();
auto callBackFunc = getCallDecl("__mlir_allocate",
                                {rewriter.getIndexType(), rewriter.getIndexType()},
-                                {m_dialectLowerer.convertType(tensorType)},
+                                {m_typeConverter.convertType(tensorType)},
                                rewriter);
auto size = tensorType.getSizeInBytes();
@@ -261,10 +288,10 @@ namespace
    return callBackFuncPtr;
}

// NGDialect converters
-Type DialectLowerer::convertType(Type type)
+Type NGraphTypeConverter::convertType(Type type)
{
    // We may need to refactor this code to a external utility if type conversion is needed
-    // outside of the lowering context since DialectLowerer is private.
+    // outside of the lowering context since NGraphTypeConverter is private.
    if (auto tensor_type = type.dyn_cast<NGTensorType>())
    {
@@ -294,7 +321,7 @@ namespace
}

#define REWRITER(OP) \
-    void OP##Conversion::rewrite( \
+    PatternMatchResult OP##Conversion::matchAndRewrite( \
        Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter) const

// ADD
@@ -334,6 +361,8 @@ namespace
    });
    // clang-format on
    rewriter.replaceOp(op, {result});
+
+    return matchSuccess();
}

REWRITER(NGDotOp)
@@ -396,9 +425,16 @@ namespace
    });
    rewriter.replaceOp(op, {result});
+
+    return matchSuccess();
+}
+
+REWRITER(NGReturnOp)
+{
+    rewriter.replaceOpWithNewOp<ReturnOp>(op);
+    return matchSuccess();
}
-REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }

#undef REWRITER
}
......
@@ -27,6 +27,8 @@ namespace ngraph
namespace ngmlir
{
class MLIRCompiler;
+
+using OwningRewritePatternList = std::vector<std::unique_ptr<mlir::RewritePattern>>;
}
}
}
......
@@ -17,17 +17,19 @@
// Add new dialect ops lowerers to this file

#define DECL_OP_CONV(OP) \
-    class OP##Conversion : public mlir::DialectConversionPattern \
-    {\
-    public:\
-        explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\
-            : mlir::DialectConversionPattern(mlir::OP::getOperationName(), 1, context),\
-              m_pass(pass)\
-        {} \
-        void rewrite(Operation *op, ArrayRef<Value *> operands, PatternRewriter &rewriter) const override; \
-        DialectLoweringPass& m_pass;\
-    };
+    class OP##Conversion : public NGraphOpLowering \
+    { \
+    public: \
+        explicit OP##Conversion(mlir::MLIRContext* context, DialectLoweringPass& pass) \
+            : NGraphOpLowering(mlir::OP::getOperationName(), context, pass) \
+        { \
+        } \
+        \
+        PatternMatchResult matchAndRewrite(Operation* op, \
+                                           ArrayRef<Value*> operands, \
+                                           PatternRewriter& rewriter) const override; \
+    };

DECL_OP_CONV(NGAddOp)
DECL_OP_CONV(NGDotOp)
......
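To make the macro change concrete, this is roughly what DECL_OP_CONV(NGAddOp) expands to after this change (a hand-expanded sketch for illustration, not compiler output):

// Hand expansion of DECL_OP_CONV(NGAddOp) with the new macro body:
class NGAddOpConversion : public NGraphOpLowering
{
public:
    explicit NGAddOpConversion(mlir::MLIRContext* context, DialectLoweringPass& pass)
        : NGraphOpLowering(mlir::NGAddOp::getOperationName(), context, pass)
    {
    }

    // The body is supplied later by the REWRITER(NGAddOp) macro in lowerer.cpp.
    PatternMatchResult matchAndRewrite(Operation* op,
                                       ArrayRef<Value*> operands,
                                       PatternRewriter& rewriter) const override;
};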
@@ -48,16 +48,6 @@ namespace ngraph
};
typedef EnumMask<FusionType> FusionTypeMask;

-// These constants are for backward compatibility only, will deprecate soon.
-NGRAPH_DEPRECATED("use FusionType enum class instead")
-constexpr FusionType DIFFERENTIABLE_FUSIONS = FusionType::DIFFERENTIABLE_FUSIONS;
-NGRAPH_DEPRECATED("use FusionType enum class instead")
-constexpr FusionType REGULAR_FUSIONS = FusionType::REGULAR_FUSIONS;
-NGRAPH_DEPRECATED("use FusionType enum class instead")
-constexpr FusionType FOP_FUSIONS = FusionType::FOP_FUSIONS;
-NGRAPH_DEPRECATED("use FusionType enum class instead")
-constexpr FusionType ALL_FUSIONS = FusionType::ALL_FUSIONS;
enum class PassProperty : uint32_t
{
// Pass requires node shapes to be static
......
@@ -217,8 +217,6 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context(Allocator* allocator)
{
// single thread for codegen
NGRAPH_CHECK(m_num_ctx == 1);
-ctx->mkldnn_primitives.swap(mkldnn_emitter->get_mkldnn_primitives());
-ctx->mkldnn_workspaces = mkldnn_emitter->get_mkldnn_workspaces();
}
ctx->states = m_external_function->m_states.data();
......
@@ -33,7 +33,10 @@
#include "ngraph/codegen/execution_engine.hpp"
#endif

+#ifdef NGRAPH_MLIR_ENABLE
#include "contrib/mlir/pass/mlir_subgraph_extraction.hpp"
+#endif
+
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/file_util.hpp"
......
@@ -57,28 +57,28 @@ namespace ngraph
{
case 0 /*Logistic|Logistic*/:
{
-auto c = (in0.exp() * in1.exp()) / ((in0.exp() + 1.f) * (in1.exp() + 1.f));
+auto c = 1.f / (((-in0).exp() + 1.f) * ((-in1).exp() + 1.f));
out_tm.device(
    ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
}
break;
case 1 /*Logistic|Tanh*/:
{
-auto c = (in0.exp() * in1.tanh()) / (in0.exp() + 1.f);
+auto c = in1.tanh() / ((-in0).exp() + 1.f);
out_tm.device(
    ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
}
break;
case 2 /*Logistic|Identity*/:
{
-auto c = (in0.exp() * in1) / (in0.exp() + 1.f);
+auto c = in1 / ((-in0).exp() + 1.f);
out_tm.device(
    ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
}
break;
case 3 /*Tanh|Logistic*/:
{
-auto c = (in0.tanh() * in1.exp()) / (in1.exp() + 1.f);
+auto c = in0.tanh() / ((-in1).exp() + 1.f);
out_tm.device(
    ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
}
@@ -99,7 +99,7 @@ namespace ngraph
break;
case 6 /*Identity|Logistic*/:
{
-auto c = (in0 * in1.exp()) / (in1.exp() + 1.f);
+auto c = in0 / ((-in1).exp() + 1.f);
out_tm.device(
    ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
}
@@ -141,10 +141,15 @@ namespace ngraph
{
case 0 /*Logistic|Logistic*/:
{
-auto i0 = delta * (in1.exp() * in0.exp()) /
-          ((in1.exp() + 1.f) * ((in0.exp() + 1.f) * (in0.exp() + 1.f)));
-auto i1 = delta * (in0.exp() * in1.exp()) /
-          ((in0.exp() + 1.f) * ((in1.exp() + 1.f) * (in1.exp() + 1.f)));
+auto in0_neg_exp = (-in0).exp();
+auto in0_log_denominator = in0_neg_exp + 1.f;
+auto in1_neg_exp = (-in1).exp();
+auto in1_log_denominator = in1_neg_exp + 1.f;
+
+auto i0 = delta * in0_neg_exp /
+          (in1_log_denominator * in0_log_denominator * in0_log_denominator);
+auto i1 = delta * in1_neg_exp /
+          (in0_log_denominator * in1_log_denominator * in1_log_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
    arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
@@ -153,12 +158,17 @@ namespace ngraph
break;
case 1 /*Logistic|Tanh*/:
{
-auto i0 =
-    delta * (((in1 * 2.f).exp() - 1.f) * in0.exp()) /
-    (((in1 * 2.f).exp() + 1.f) * ((in0.exp() + 1.f) * (in0.exp() + 1.f)));
-auto i1 = delta * (in0.exp() * (4.f * (in1 * 2.f).exp())) /
-          ((in0.exp() + 1.f) *
-           (((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f)));
+auto in0_neg_exp = (-in0).exp();
+auto in0_log_denominator = in0_neg_exp + 1.f;
+auto in1_2exp = (in1 * 2.f).exp();
+auto in1_tanh_denominator = in1_2exp + 1.f;
+
+auto i0 =
+    delta * ((in1_2exp - 1.f) * in0_neg_exp) /
+    (in1_tanh_denominator * in0_log_denominator * in0_log_denominator);
+auto i1 =
+    delta * (4.f * in1_2exp) /
+    (in0_log_denominator * in1_tanh_denominator * in1_tanh_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
    arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
@@ -167,9 +177,12 @@ namespace ngraph
break;
case 2 /*Logistic|Identity*/:
{
-auto i0 =
-    delta * (in1 * in0.exp()) / ((in0.exp() + 1.f) * (in0.exp() + 1.f));
-auto i1 = delta * in0.exp() / ((in0.exp() + 1.f));
+auto in0_neg_exp = (-in0).exp();
+auto in0_log_denominator = in0_neg_exp + 1.f;
+
+auto i0 = delta * (in1 * in0_neg_exp) /
+          (in0_log_denominator * in0_log_denominator);
+auto i1 = delta / in0_log_denominator;
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
    arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
@@ -178,12 +191,17 @@ namespace ngraph
break;
case 3 /*Tanh|Logistic*/:
{
-auto i0 = delta * (in1.exp() * (4.f * (in0 * 2.f).exp())) /
-          ((in1.exp() + 1.f) * ((in0 * 2.f).exp() + 1.f) *
-           ((in0 * 2.f).exp() + 1.f));
-auto i1 =
-    delta * (((in0 * 2.f).exp() - 1.f) * in1.exp()) /
-    (((in0 * 2.f).exp() + 1.f) * ((in1.exp() + 1.f) * (in1.exp() + 1.f)));
+auto in0_2exp = (in0 * 2.f).exp();
+auto in0_tanh_denominator = in0_2exp + 1.f;
+auto in1_neg_exp = (-in1).exp();
+auto in1_log_denominator = in1_neg_exp + 1.f;
+
+auto i0 =
+    delta * (4.f * in0_2exp) /
+    (in1_log_denominator * in0_tanh_denominator * in0_tanh_denominator);
+auto i1 =
+    delta * ((in0_2exp - 1.f) * in1_neg_exp) /
+    (in0_tanh_denominator * in1_log_denominator * in1_log_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
    arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
@@ -192,12 +210,17 @@ namespace ngraph
break;
case 4 /*Tanh|Tanh*/:
{
-auto i0 = delta * (((in1 * 2.f).exp() - 1.f) * (4.f * (in0 * 2.f).exp())) /
-          (((in1 * 2.f).exp() + 1.f) *
-           (((in0 * 2.f).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f)));
-auto i1 = delta * (((in0 * 2.f).exp() - 1.f) * (4.f * (in1 * 2.f).exp())) /
-          (((in0 * 2.f).exp() + 1.f) *
-           (((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f)));
+auto in0_2exp = (in0 * 2.f).exp();
+auto in0_tanh_denominator = in0_2exp + 1.f;
+auto in1_2exp = (in1 * 2.f).exp();
+auto in1_tanh_denominator = in1_2exp + 1.f;
+
+auto i0 =
+    delta * (in1_2exp - 1.f) * 4.f * in0_2exp /
+    (in1_tanh_denominator * in0_tanh_denominator * in0_tanh_denominator);
+auto i1 =
+    delta * (in0_2exp - 1.f) * 4.f * in1_2exp /
+    (in0_tanh_denominator * in1_tanh_denominator * in1_tanh_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
    arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
@@ -206,9 +229,12 @@ namespace ngraph
break;
case 5 /*Tanh|Identity*/:
{
-auto i0 = delta * (in1 * (4.f * (in0 * 2.f).exp())) /
-          (((in0 * 2.f).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f));
-auto i1 = delta * ((in0 * 2.f).exp() - 1.f) / ((in0 * 2.f).exp() + 1.f);
+auto in0_2exp = (in0 * 2.f).exp();
+auto in0_tanh_denominator = in0_2exp + 1.f;
+
+auto i0 = delta * in1 * 4.f * in0_2exp /
+          (in0_tanh_denominator * in0_tanh_denominator);
+auto i1 = delta * (in0_2exp - 1.f) / in0_tanh_denominator;
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
    arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
@@ -217,9 +243,12 @@ namespace ngraph
break;
case 6 /*Identity|Logistic*/:
{
-auto i0 = delta * (in1.exp()) / (in1.exp() + 1.f);
-auto i1 =
-    delta * (in0 * in1.exp()) / ((in1.exp() + 1.f) * (in1.exp() + 1.f));
+auto in1_neg_exp = (-in1).exp();
+auto in1_log_denominator = in1_neg_exp + 1.f;
+
+auto i0 = delta * 1.f / in1_log_denominator;
+auto i1 =
+    delta * in0 * in1_neg_exp / (in1_log_denominator * in1_log_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
    arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
@@ -228,9 +257,12 @@ namespace ngraph
break;
case 7 /*Identity|Tanh*/:
{
-auto i0 = delta * ((in1 * 2.f).exp() - 1.f) / ((in1 * 2.f).exp() + 1.f);
-auto i1 = delta * (in0 * (4.f * (in1 * 2.f).exp())) /
-          (((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f));
+auto in1_2exp = (in1 * 2.f).exp();
+auto in1_tanh_denominator = in1_2exp + 1.f;
+
+auto i0 = delta * (in1_2exp - 1.f) / in1_tanh_denominator;
+auto i1 = delta * (in0 * (4.f * in1_2exp)) /
+          (in1_tanh_denominator * in1_tanh_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
    arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......
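Every change in this file applies the same identity: exp(x) / (exp(x) + 1) is rewritten as 1 / (exp(-x) + 1), and tanh is kept in its exp(2x) form, so exp is never evaluated at a large positive argument. A standalone sketch of why that matters, using plain floats and hypothetical helper names rather than the Eigen tensor expressions in the kernel:

#include <cmath>
#include <cstdio>
#include <initializer_list>

// Naive logistic: exp(x) overflows float for x > ~88, giving inf/inf = NaN.
float logistic_naive(float x) { return std::exp(x) / (std::exp(x) + 1.f); }

// Rewritten form from this commit: exp(-x) merely underflows to 0 for large
// positive x, so the quotient cleanly evaluates to 1.
float logistic_stable(float x) { return 1.f / (std::exp(-x) + 1.f); }

int main()
{
    for (float x : {1.f, 50.f, 100.f})
    {
        // At x = 100 the naive form prints nan, the stable form prints 1.
        std::printf("x=%6.1f  naive=%f  stable=%f\n", x, logistic_naive(x), logistic_stable(x));
    }
    return 0;
}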
@@ -173,9 +173,6 @@ void runtime::gpu::GPUCompiledFunction::compile()
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
-// Run this pass for the second time since, some fused operators like LSTMCell may use
-// other fused operators inside.
-pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.register_pass<ngraph::pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
......
@@ -430,10 +430,6 @@ shared_ptr<runtime::Executable>
if (m_disable_backend_optimizations < 2)
{
-pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
-    IntelGPUBackend::is_supported_impl);
-// Run this pass for the second time since, some fused operators like LSTMCell may use
-// other fused operators inside.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
    IntelGPUBackend::is_supported_impl);
pass_manager.register_pass<ngraph::pass::ImplicitBroadcastElimination>();
......
@@ -47,9 +47,6 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::FusedOpDecomposition>();
-// Run this pass for the second time since, some fused operators like LSTMCell may use
-// other fused operators inside.
-pass_manager.register_pass<pass::FusedOpDecomposition>();
pass_manager.register_pass<pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>();
......
@@ -41,6 +41,7 @@ set(SRC
plaidml_ops_one_hot.cpp
plaidml_ops_passthrough.cpp
plaidml_ops_pool.cpp
+plaidml_ops_quantize.cpp
plaidml_ops_reduce.cpp
plaidml_ops_replace_slice.cpp
plaidml_ops_replicate.cpp
......
@@ -188,7 +188,8 @@ class ngraph::runtime::plaidml::builder::Elementwise final : public Statement
{
public:
Elementwise(std::string lhs, std::string rhs);
+void set_lhs(const std::string& lhs) { m_lhs = lhs; }
+void set_rhs(const std::string& rhs) { m_rhs = rhs; }
private:
friend class Function;
......
@@ -20,6 +20,7 @@
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/cse.hpp"
+#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
@@ -87,6 +88,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
ngraph::pass::Manager pass_manager;

// We apply the same general-purposes passes as the CPU backend.
+pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
......
@@ -242,6 +242,8 @@ ngraph::runtime::plaidml::Config
}

// Reject unknown options
+NGRAPH_ERR << "Unrecognized PlaidML backend option: "
+           << std::string{oname_begin, static_cast<std::size_t>(oname_end - oname_begin)};
err = true;
}
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
NGRAPH_PLAIDML_OP_CLASS(ImplDequantize, OpImpl<op::Dequantize>);
NGRAPH_PLAIDML_OP_CLASS(ImplQuantize, OpImpl<op::Quantize>);
}
}
}
void ngraph::runtime::plaidml::ImplDequantize::Apply()
{
check_inputs(3);
check_outputs(1);
const auto& axes = op().get_axes();
const auto& input_shape = op().get_input_shape(0);
const auto& scale_shape = op().get_input_shape(1);
const auto& zp_shape = op().get_input_shape(2);
const auto& input_type = op().get_input_element_type(0);
if (!input_type.is_signed() && input_type.size() >= 8)
{
throw std::runtime_error("PlaidML does not yet support dequantizing from uint64+");
}
if (scale_shape != zp_shape)
{
throw std::runtime_error("Dequantize given mismatched scale & zero point shapes.");
}
if (scale_shape.size() != axes.size())
{
std::ostringstream msg;
msg << "Dequantize received " << axes.size()
<< " axes to use for scale & zero point, but those tensors have " << scale_shape.size()
<< " dimensions instead.";
throw std::runtime_error(msg.str());
}
std::vector<std::string> short_idxs;
for (size_t i = 0; i < input_shape.size(); ++i)
{
if (axes.count(i))
{
std::ostringstream name;
name << "i" << i;
short_idxs.push_back(name.str());
}
}
builder::ContractionInput scale_input{"S"};
builder::ContractionInput neg_zp_input{"NegZ"};
for (const auto& idx : short_idxs)
{
scale_input.add_indices({idx});
neg_zp_input.add_indices({idx});
}
std::function<std::string(std::string)> cast_uint_to_wider_int =
[input_type](std::string tensor_name) {
std::ostringstream cast_str;
if (!input_type.is_signed())
{
cast_str << "as_int(" << tensor_name << ", " << 2 * 8 * input_type.size() << ")";
}
else
{
cast_str << tensor_name;
}
return cast_str.str();
};
builder::Elementwise CastI{"CastI", cast_uint_to_wider_int("I")};
builder::Elementwise CastZ{"CastZ", cast_uint_to_wider_int("Z")};
auto f = start_tile_function();
f.add(builder::Input{op_input(0), "I"}.add_dims("I", 0, input_shape.size()))
.add(builder::Input{op_input(1), "S"}.add_dims("S", 0, scale_shape.size()))
.add(builder::Input{op_input(2), "Z"}.add_dims("Z", 0, zp_shape.size()))
.add(builder::Output{"O"})
.add(CastI)
.add(CastZ)
.add(builder::Elementwise{"NegZ", "-CastZ"})
.add(
builder::BinaryContraction{"=", "+"}
.set(builder::ContractionOutput{"Offset"}
.add_indices("i", 0, input_shape.size())
.add_dims("I", 0, input_shape.size()))
.set_lhs(builder::ContractionInput{"CastI"}.add_indices("i", 0, input_shape.size()))
.set_rhs(neg_zp_input))
.add(builder::BinaryContraction{"=", "*"}
.set(builder::ContractionOutput{"O"}
.add_indices("i", 0, input_shape.size())
.add_dims("I", 0, input_shape.size()))
.set_lhs(
builder::ContractionInput{"Offset"}.add_indices("i", 0, input_shape.size()))
.set_rhs(scale_input));
set_output(f.finalize());
}
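The Tile program built in Apply() above implements the standard affine dequantization O = (I - Z) * S, with unsigned inputs first widened to a signed type twice as wide so the zero-point subtraction cannot wrap. A scalar sketch of the same arithmetic, with a hypothetical helper name and uint8 input assumed:

// Scalar model of the dequantize contraction built above.
#include <cstdint>

float dequantize_scalar(uint8_t input, float scale, uint8_t zero_point)
{
    // "CastI"/"CastZ": as_int(I, 16) widens the unsigned values so the
    // subtraction (the "Offset" contraction with NegZ) cannot wrap around.
    int16_t wide_input = static_cast<int16_t>(input);
    int16_t wide_zero = static_cast<int16_t>(zero_point);
    int16_t offset = static_cast<int16_t>(wide_input - wide_zero); // Offset = CastI + NegZ
    return offset * scale;                                         // O = Offset * S
}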
void ngraph::runtime::plaidml::ImplQuantize::Apply()
{
check_inputs(3);
check_outputs(1);
const auto& type = op().get_output_element_type(0);
const auto& axes = op().get_axes();
const auto& round_mode = op().get_round_mode();
const auto& input_shape = op().get_input_shape(0);
const auto& scale_shape = op().get_input_shape(1);
const auto& zp_shape = op().get_input_shape(2);
std::function<std::string(std::string)> cast_to_output_type = [type](std::string tensor_name) {
std::ostringstream cast_str;
if (type.is_signed())
{
cast_str << "as_int";
}
else
{
cast_str << "as_uint";
}
cast_str << "(" << tensor_name << ", " << 8 * type.size() << ")";
return cast_str.str();
};
if (scale_shape != zp_shape)
{
throw std::runtime_error("Quantize given mismatched scale & zero point shapes.");
}
if (scale_shape.size() != axes.size())
{
std::ostringstream msg;
msg << "Quantize received " << axes.size()
<< " axes to use for scale & zero point, but those tensors have " << scale_shape.size()
<< " dimensions instead.";
throw std::runtime_error(msg.str());
}
std::vector<std::string> short_idxs;
for (size_t i = 0; i < input_shape.size(); ++i)
{
if (axes.count(i))
{
std::ostringstream name;
name << "i" << i;
short_idxs.push_back(name.str());
}
}
if (!type.is_integral())
{
throw std::runtime_error("Quantize output type must be integral");
}
builder::Elementwise Rounded{"Rounded", ""};
builder::Elementwise Clamped{"Clamped", ""};
builder::Elementwise O{"O", ""};
int64_t q_min;
int64_t q_max;
std::ostringstream clamp_formula;
if (type.size() > 4)
{
// PlaidML doesn't support quantization clamping for types wider than 32 bits
if (!type.is_signed())
{
clamp_formula << "Uncast < 0 ? 0 : Uncast";
}
else
{
clamp_formula << "Uncast";
}
}
else
{
if (type.is_signed())
{
q_max = (1 << (8 * type.size() - 1)) - 1;
q_min = -q_max - 1;
}
else
{
q_max = (1 << (8 * type.size())) - 1;
q_min = 0;
}
clamp_formula << "Uncast < " << q_min << " ? " << q_min << " : "
<< "(Uncast > " << q_max << " ? " << q_max << " : Uncast)";
}
Clamped.set_rhs(clamp_formula.str());
std::ostringstream round_formula;
std::string lower_rounded_int;
switch (round_mode)
{
case ngraph::op::Quantize::RoundMode::ROUND_DOWN: Rounded.set_rhs("floor(Frac)"); break;
case ngraph::op::Quantize::RoundMode::ROUND_UP: Rounded.set_rhs("ceil(Frac)"); break;
case ngraph::op::Quantize::RoundMode::ROUND_NEAREST_DOWNWARD:
Rounded.set_rhs("ceil(Frac - 0.5)");
break;
case ngraph::op::Quantize::RoundMode::ROUND_NEAREST_UPWARD:
Rounded.set_rhs("floor(Frac + 0.5)");
break;
case ngraph::op::Quantize::RoundMode::ROUND_TOWARD_ZERO:
Rounded.set_rhs("Frac > 0 ? floor(Frac) : ceil(Frac)");
break;
case ngraph::op::Quantize::RoundMode::ROUND_TOWARD_INFINITY:
Rounded.set_rhs("Frac < 0 ? floor(Frac) : ceil(Frac)");
break;
case ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_ZERO:
Rounded.set_rhs("Frac > 0 ? ceil(Frac - 0.5) : floor(Frac + 0.5)");
break;
case ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY:
Rounded.set_rhs("Frac < 0 ? ceil(Frac - 0.5) : floor(Frac + 0.5)");
break;
case ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN:
// This is ugly, but it produces correct output
lower_rounded_int = cast_to_output_type("ceil(Frac - 0.5)");
round_formula << "2 * (" << lower_rounded_int << " / 2) == " << lower_rounded_int
<< " ? ceil(Frac - 0.5) : floor(Frac + 0.5)";
Rounded.set_rhs(round_formula.str());
break;
default:
throw std::runtime_error("Requested quantize round mode not yet implemented in PlaidML");
}
O.set_rhs(cast_to_output_type("Clamped"));
builder::ContractionInput scale_recip_input{"SRecip"};
builder::ContractionInput zp_input{"Z"};
for (const auto& idx : short_idxs)
{
scale_recip_input.add_indices({idx});
zp_input.add_indices({idx});
}
auto f = start_tile_function();
f.add(builder::Input{op_input(0), "I"}.add_dims("I", 0, input_shape.size()))
.add(builder::Input{op_input(1), "S"}.add_dims("S", 0, scale_shape.size()))
.add(builder::Input{op_input(2), "Z"}.add_dims("Z", 0, zp_shape.size()))
.add(builder::Output{"O"})
.add(builder::Elementwise{"SRecip", "1 / S"})
.add(builder::BinaryContraction{"=", "*"}
.set(builder::ContractionOutput{"Frac"}
.add_indices("i", 0, input_shape.size())
.add_dims("I", 0, input_shape.size()))
.set_lhs(builder::ContractionInput{"I"}.add_indices("i", 0, input_shape.size()))
.set_rhs(scale_recip_input))
.add(Rounded)
.add(builder::BinaryContraction{"=", "+"}
.set(builder::ContractionOutput{"Uncast"}
.add_indices("i", 0, input_shape.size())
.add_dims("I", 0, input_shape.size()))
.set_lhs(
builder::ContractionInput{"Rounded"}.add_indices("i", 0, input_shape.size()))
.set_rhs(zp_input))
.add(Clamped)
.add(O);
set_output(f.finalize());
}
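The ROUND_NEAREST_TOWARD_EVEN branch relies on an integer-parity trick: for the candidate lower = ceil(Frac - 0.5), the test 2 * (lower / 2) == lower (with truncating integer division) holds exactly when lower is even, in which case the half-way value rounds down to it; otherwise floor(Frac + 0.5) rounds up. A scalar sketch of that branch plus the clamp, with a hypothetical helper name and int8 output assumed:

// Scalar model of Quantize for int8 output with ROUND_NEAREST_TOWARD_EVEN,
// mirroring the Tile formulas above.
#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t quantize_scalar(float input, float scale, int8_t zero_point)
{
    float frac = input / scale; // Frac = I * SRecip, where SRecip = 1 / S
    // Parity test on the round-half-down candidate: keep it when even,
    // otherwise round half up (e.g. 2.5 -> 2, 3.5 -> 4, -2.5 -> -2).
    int32_t lower = static_cast<int32_t>(std::ceil(frac - 0.5f));
    float rounded = (2 * (lower / 2) == lower) ? std::ceil(frac - 0.5f)
                                               : std::floor(frac + 0.5f);
    int32_t uncast = static_cast<int32_t>(rounded) + zero_point; // Uncast = Rounded + Z
    // Clamp to the representable range: q_min = -128, q_max = 127 for int8.
    int32_t clamped = std::min<int32_t>(127, std::max<int32_t>(-128, uncast));
    return static_cast<int8_t>(clamped);
}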
@@ -60,37 +60,7 @@ generate_mask
generate_mask2
avg_pool_3d
avg_pool_3d_uneven_strided_padded_include_in_computation
-quantize_dynamic_offset # Quantization/Dequantization is unimplemented
-dequantize_dynamic_offset # Quantization/Dequantization is unimplemented
-dequantize_int8_zero_offset # Quantization/Dequantization is unimplemented
-dequantize_int32 # Quantization/Dequantization is unimplemented
-dequantize_int32_zero_offset # Quantization/Dequantization is unimplemented
-dequantize_zero_offset # Quantization/Dequantization is unimplemented
-quantize_ROUND_NEAREST_TOWARD_ZERO # Quantization/Dequantization is unimplemented
-quantize_ROUND_NEAREST_UPWARD # Quantization/Dequantization is unimplemented
-quantize_ROUND_NEAREST_DOWNWARD # Quantization/Dequantization is unimplemented
-quantize_ROUND_NEAREST_TOWARD_EVEN # Quantization/Dequantization is unimplemented
-quantize_ROUND_NEAREST_TOWARD_INFINITY # Quantization/Dequantization is unimplemented
-quantize_ROUND_TOWARD_INFINITY # Quantization/Dequantization is unimplemented
-quantize_ROUND_TOWARD_ZERO # Quantization/Dequantization is unimplemented
-quantize_ROUND_UP # Quantization/Dequantization is unimplemented
-quantize_ROUND_DOWN # Quantization/Dequantization is unimplemented
-quantize # Quantization/Dequantization is unimplemented
-quantize_zero_offset # Quantization/Dequantization is unimplemented
-quantize_axes # Quantization/Dequantization is unimplemented
-quantize_dynamic_offset # Quantization/Dequantization is unimplemented
-quantize_int8 # Quantization/Dequantization is unimplemented
-quantize_int8_zero_offset # Quantization/Dequantization is unimplemented
-quantize_int32 # Quantization/Dequantization is unimplemented
-quantize_int32_zero_offset # Quantization/Dequantization is unimplemented
-quantize_clamp # Quantization/Dequantization is unimplemented
-quantize_clamp_int8 # Quantization/Dequantization is unimplemented
-quantize_clamp_int32 # Quantization/Dequantization is unimplemented
-quantize_clamp_int32_zero_offset # Quantization/Dequantization is unimplemented
-quantize_clamp_uint8 # Quantization/Dequantization is unimplemented
-dequantize # Quantization/Dequantization is unimplemented
-dequantize_axes # Quantization/Dequantization is unimplemented
-dequantize_int8 # Quantization/Dequantization is unimplemented
+quantize_clamp_int32 # Requires fp64 inputs, which won't work on GPUs
numeric_float_nan
numeric_double_nan
shape_of_scalar
@@ -259,12 +229,6 @@ backwards_softmax_underflow
backwards_softmax_3d
batch_mat_mul_forward
dot_matrix_2x0_0x2
-rnn_cell_no_bias
-rnn_cell_bias_clip
-rnn_cell_activation_function
-gru_cell_bias_clip
-gru_cell_linear_before_reset
-gru_cell_activation_function
# dgkutnic ww24.5: these tests are to be triaged by the PlaidML team
# ww25.2: re-scrubbed this list of tests after fixing check_inputs
@@ -289,3 +253,29 @@ group_conv_transpose
group_conv_transpose_output_shape
divide_python_rounding_int32
backwards_batchmatmul_tensor2_tensor2
+
+# unsupported ops: 'QuantizedConvolution', 'QuantizedDot', 'TopK', 'Erf', 'EmbeddingLookup'
+model_quant_conv_linear
+model_conv_integer_no_zero_point
+model_matmul_integer_no_zero_point
+model_matmul_integer_4d_no_zero_point
+model_top_k
+model_erf
+model_erf_int32
+model_hardmax
+
+# node validation error: "Argument shapes are inconsistent."
+model_lstm_fwd_with_clip
+model_lstm_fwd_mixed_seq
+model_lstm_fwd_hardsigmoid_activation
+model_reduce_log_sum
+model_reduce_log_sum_exp
+model_reduce_mean
+
+# result mismatch
+model_dequantize_linear_scalar_zero_scale_int8
+model_softmax
+avg_pool_3d_uneven_strided_padded
+rnn_cell_activation_function
+gru_cell_bias_clip
+gru_cell_linear_before_reset
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: ""
output: "Y_h"
op_type: "LSTM"
attribute {
name: "clip"
f: 9999.0
type: FLOAT
}
attribute {
name: "direction"
s: "forward"
type: STRING
}
attribute {
name: "hidden_size"
i: 3
type: INT
}
}
name: "compute_graph"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 32
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 32
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 7
}
@@ -20,6 +20,7 @@
#include <fstream>
#include <iterator>
#include <limits>
+#include <numeric>
#include <sstream>
#include <stdexcept>
#include <vector>
@@ -203,3 +204,48 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
test_case.set_tolerance(6);
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_large_batch_no_clip)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_large_batch_no_clip.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
std::size_t seq_length = 2;
std::size_t batch_size = 32;
std::size_t input_size = 1;
std::size_t hidden_size = 3;
std::vector<float> in_X(seq_length * batch_size * input_size);
std::iota(std::begin(in_X), std::end(in_X), 1.f);
std::vector<float> in_R(4 * hidden_size * hidden_size, 0.1f);
// X
test_case.add_input<float>(in_X);
// W
test_case.add_input<float>(
{0.1f, 0.2f, 0.3f, 0.4f, 1.f, 2.f, 3.f, 4.f, 10.f, 11.f, 12.f, 13.f});
// R
test_case.add_input<float>(in_R);
// Y_h_data
test_case.add_expected_output<float>(
Shape{1, batch_size, hidden_size},
{0.90387899f, 0.9135572f, 0.91772245f, 0.90897038f, 0.92132433f, 0.92825467f, 0.91365823f,
0.92815113f, 0.93676105f, 0.91799162f, 0.93406357f, 0.94344562f, 0.92199681f, 0.93912057f,
0.94859476f, 0.92569357f, 0.94340185f, 0.95250664f, 0.92909964f, 0.94699686f, 0.95545127f,
0.93223207f, 0.94999634f, 0.95765468f, 0.93510761f, 0.9524867f, 0.95929726f, 0.93774272f,
0.9545467f, 0.96051891f, 0.9401536f, 0.95624603f, 0.96142619f, 0.94235605f, 0.95764499f,
0.96209939f, 0.94436539f, 0.95879495f, 0.96259862f, 0.94619635f, 0.95973921f, 0.96296872f,
0.94786299f, 0.96051397f, 0.96324302f, 0.94937864f, 0.96114929f, 0.96344629f, 0.95075587f,
0.96167006f, 0.96359692f, 0.95200645f, 0.96209679f, 0.96370852f, 0.95314133f, 0.9624464f,
0.9637912f, 0.95417069f, 0.96273278f, 0.96385246f, 0.95510395f, 0.96296733f, 0.96389785f,
0.95594975f, 0.96315942f, 0.96393147f, 0.95671607f, 0.96331673f, 0.96395638f, 0.9574102f,
0.96344554f, 0.96397483f, 0.9580388f, 0.96355102f, 0.9639885f, 0.95860795f, 0.96363739f,
0.96399863f, 0.95912322f, 0.96370811f, 0.96400613f, 0.95958963f, 0.96376601f, 0.96401169f,
0.96001179f, 0.96381342f, 0.96401581f, 0.96039386f, 0.96385224f, 0.96401886f, 0.96073964f,
0.96388402f, 0.96402112f, 0.96105254f, 0.96391004f, 0.96402279f});
test_case.run();
}
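For reference, the recurrence this test exercises is the standard ONNX LSTM with default activations (sigmoid for the gates, tanh for the cell updates), zero initial states, zero biases (the model has no B input), and a clip threshold of 9999.0 that never triggers; the empty first output name in the prototxt means the full sequence output Y is omitted, and Y_h is h_t after the second time step:

\begin{aligned}
i_t &= \sigma(W_i x_t + R_i h_{t-1}) \\
o_t &= \sigma(W_o x_t + R_o h_{t-1}) \\
f_t &= \sigma(W_f x_t + R_f h_{t-1}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tanh(W_c x_t + R_c h_{t-1}) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}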