Unverified Commit 41e55abf authored by ECouzens, committed by GitHub

Merge branch 'master' into resnet-readme-doc

parents 20666f59 00125a23
......@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required (VERSION 2.8)
cmake_minimum_required (VERSION 3.1)
set(NGRAPH_INCLUDE_PATH
${CMAKE_CURRENT_SOURCE_DIR}/src
......
# Copyright 2017 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
if(NGRAPH_CPU_ENABLE AND (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") AND
(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows"))
message(STATUS "Fetching LLVM from llvm.org")
set(LLVM_RELEASE_URL http://releases.llvm.org/5.0.0/clang+llvm-5.0.0-linux-x86_64-ubuntu16.04.tar.xz)
set(LLVM_SHA1_HASH 9cb81c92aa4d3f9707a9b8413c4d24b8dee90c59)
# Override default LLVM binaries
if(PREBUILT_LLVM)
if(NOT DEFINED PREBUILT_LLVM_HASH)
message(FATAL_ERROR "SHA1 hash of prebuilt llvm tarball not provided in PREBUILT_LLVM_HASH.")
endif()
set(LLVM_RELEASE_URL ${PREBUILT_LLVM})
set(LLVM_SHA1_HASH ${PREBUILT_LLVM_HASH})
endif()
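# For example (hypothetical paths shown), a locally mirrored tarball can be selected
# at configure time with:
#   cmake .. -DPREBUILT_LLVM=file:///opt/tarballs/clang+llvm-5.0.0.tar.xz \
#            -DPREBUILT_LLVM_HASH=<sha1-of-that-tarball>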
# The 'BUILD_BYPRODUCTS' argument was introduced in CMake 3.2.
if(${CMAKE_VERSION} VERSION_LESS 3.2)
ExternalProject_Add(
ext_llvm
URL ${LLVM_RELEASE_URL}
URL_HASH SHA1=${LLVM_SHA1_HASH}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
UPDATE_COMMAND ""
)
else()
ExternalProject_Add(
ext_llvm
URL ${LLVM_RELEASE_URL}
URL_HASH SHA1=${LLVM_SHA1_HASH}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
UPDATE_COMMAND ""
BUILD_BYPRODUCTS "${CMAKE_CURRENT_BINARY_DIR}/ext_llvm-prefix/src/ext_llvm/lib/libLLVMCore.a"
)
endif()
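# Declaring libLLVMCore.a as a byproduct matters for generators such as Ninja, which
# otherwise have no way of knowing that the extraction step produces this file.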
ExternalProject_Get_Property(ext_llvm source_dir)
set(LLVM_INCLUDE_DIR "${source_dir}/include" PARENT_SCOPE)
set(LLVM_LIB_DIR "${source_dir}/lib" PARENT_SCOPE)
set(LLVM_LINK_LIBS clangTooling clangFrontendTool clangFrontend clangDriver clangSerialization clangCodeGen clangParse clangSema clangStaticAnalyzerFrontend clangStaticAnalyzerCheckers clangStaticAnalyzerCore clangAnalysis clangARCMigrate clangRewriteFrontend clangEdit clangAST clangLex clangBasic LLVMLTO LLVMPasses LLVMObjCARCOpts LLVMSymbolize LLVMDebugInfoPDB LLVMDebugInfoDWARF LLVMMIRParser LLVMCoverage LLVMTableGen LLVMDlltoolDriver LLVMOrcJIT LLVMXCoreDisassembler LLVMXCoreCodeGen LLVMXCoreDesc LLVMXCoreInfo LLVMXCoreAsmPrinter LLVMSystemZDisassembler LLVMSystemZCodeGen LLVMSystemZAsmParser LLVMSystemZDesc LLVMSystemZInfo LLVMSystemZAsmPrinter LLVMSparcDisassembler LLVMSparcCodeGen LLVMSparcAsmParser LLVMSparcDesc LLVMSparcInfo LLVMSparcAsmPrinter LLVMPowerPCDisassembler LLVMPowerPCCodeGen LLVMPowerPCAsmParser LLVMPowerPCDesc LLVMPowerPCInfo LLVMPowerPCAsmPrinter LLVMNVPTXCodeGen LLVMNVPTXDesc LLVMNVPTXInfo LLVMNVPTXAsmPrinter LLVMMSP430CodeGen LLVMMSP430Desc LLVMMSP430Info LLVMMSP430AsmPrinter LLVMMipsDisassembler LLVMMipsCodeGen LLVMMipsAsmParser LLVMMipsDesc LLVMMipsInfo LLVMMipsAsmPrinter LLVMLanaiDisassembler LLVMLanaiCodeGen LLVMLanaiAsmParser LLVMLanaiDesc LLVMLanaiAsmPrinter LLVMLanaiInfo LLVMHexagonDisassembler LLVMHexagonCodeGen LLVMHexagonAsmParser LLVMHexagonDesc LLVMHexagonInfo LLVMBPFDisassembler LLVMBPFCodeGen LLVMBPFDesc LLVMBPFInfo LLVMBPFAsmPrinter LLVMARMDisassembler LLVMARMCodeGen LLVMARMAsmParser LLVMARMDesc LLVMARMInfo LLVMARMAsmPrinter LLVMAMDGPUDisassembler LLVMAMDGPUCodeGen LLVMAMDGPUAsmParser LLVMAMDGPUDesc LLVMAMDGPUInfo LLVMAMDGPUAsmPrinter LLVMAMDGPUUtils LLVMAArch64Disassembler LLVMAArch64CodeGen LLVMAArch64AsmParser LLVMAArch64Desc LLVMAArch64Info LLVMAArch64AsmPrinter LLVMAArch64Utils LLVMObjectYAML LLVMLibDriver LLVMOption LLVMX86Disassembler LLVMX86AsmParser LLVMX86CodeGen LLVMGlobalISel LLVMSelectionDAG LLVMAsmPrinter LLVMDebugInfoCodeView LLVMDebugInfoMSF LLVMX86Desc LLVMMCDisassembler LLVMX86Info LLVMX86AsmPrinter LLVMX86Utils LLVMMCJIT LLVMLineEditor LLVMInterpreter LLVMExecutionEngine LLVMRuntimeDyld LLVMCodeGen LLVMTarget LLVMCoroutines LLVMipo LLVMInstrumentation LLVMVectorize LLVMScalarOpts LLVMLinker LLVMIRReader LLVMAsmParser LLVMInstCombine LLVMTransformUtils LLVMBitWriter LLVMAnalysis LLVMProfileData LLVMObject LLVMMCParser LLVMMC LLVMBitReader LLVMCore LLVMBinaryFormat LLVMSupport LLVMDemangle tinfo z m PARENT_SCOPE)
endif()
......@@ -74,6 +74,7 @@ check_cpu: build_ngraph_cpp_cpu
docker run --rm --tty \
${VOLUME} \
${DOCKER_RUN_ENV} \
--env GTEST_OUTPUT="xml:${DOCKUSER_HOME}/ngraph-cpp-test/BUILD/unit-test-results.xml" \
--env RUN_UID="$(shell id -u)" \
--env RUN_CMD="set -e ; set -o pipefail ; cd ${DOCKUSER_HOME}/ngraph-cpp-test/BUILD; cmake -DCMAKE_CXX_COMPILER=clang++-3.9 -DCMAKE_C_COMPILER=clang-3.9 -DNGRAPH_BUILD_DOXYGEN_DOCS=ON -DNGRAPH_BUILD_SPHINX_DOCS=ON .. 2>&1 | tee cmake.log ; env VERBOSE=1 make ${PARALLEL} 2>&1 | tee make.log ; env VERBOSE=1 make check 2>&1 | tee make_check.log" \
"ngraph_cpp_cpu:${BUILD_VERSION}" \
......
# Copyright 2017 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
NGRAPH_DIST_DIR = ${HOME}/ngraph_dist
CXXFLAGS += -std=c++11
CPPFLAGS += -I $(NGRAPH_DIST_DIR)
LDFLAGS = -L $(NGRAPH_DIST_DIR)
OBJ = main.o
%.o: %.cpp $(DEPS)
$(CXX) -c -o $@ $< $(CXXFLAGS) $(CPPFLAGS)
ngraph-test: $(OBJ)
$(CXX) -o $@ $(OBJ) $(LDFLAGS) -lngraph
.PHONY: clean
clean:
rm -f $(OBJ) ngraph-test
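# Example usage (an assumption: ngraph_dist was installed to $(NGRAPH_DIST_DIR) and
# contains libngraph.so alongside its headers):
#   make ngraph-test
#   LD_LIBRARY_PATH=$(NGRAPH_DIST_DIR) ./ngraph-test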
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <stdio.h>
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/dot.hpp"
using namespace std;
using namespace ngraph;
int main(int argc, char** argv)
{
printf( "Building graph\n" );
// Function with 4 parameters
auto arg0 = op::parameter(element::Float::type, {7, 3});
auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::Float::type, {32, 7});
auto arg3 = op::parameter(element::Float::type, {32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto dot = op::dot(arg2, arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
auto result = cluster_0->result();
printf( "Finished\n" );
}
\ No newline at end of file
......@@ -13,9 +13,6 @@
set (SRC
autodiff/adjoints.cpp
autodiff/backprop_derivative.cpp
autodiff/backprop_function.cpp
autodiff/numeric_derivative.cpp
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp
......@@ -29,8 +26,8 @@ set (SRC
node.cpp
ops/add.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
ops/binary_elementwise.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/constant.cpp
......@@ -55,8 +52,7 @@ set (SRC
ops/sum.cpp
ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
pass/assign_tensors.cpp
ops/unary_elementwise.cpp
pass/collect_functions.cpp
pass/dump_sorted.cpp
pass/liveness.cpp
......@@ -65,7 +61,6 @@ set (SRC
pass/memory_layout.cpp
pass/memory_visualize.cpp
pass/pass.cpp
pass/propagate_types.cpp
pass/topological_sort.cpp
pass/visualize_tree.cpp
runtime/backend.cpp
......@@ -99,17 +94,77 @@ include_directories(
"${EIGEN_INCLUDE_DIR}"
)
if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
MKLDNN_INCLUDE_DIR)
find_package(ZLIB REQUIRED)
include_directories(SYSTEM ${LLVM_INCLUDE_DIR} ${MKLDNN_INCLUDE_DIR})
link_directories(${LLVM_LIB_DIR} ${MKLDNN_LIB_DIR})
# Add sources for the CPU backend
# and all its dependencies
set(SRC ${SRC}
codegen/compiler.cpp
runtime/cpu/call_frame.cpp
runtime/cpu/cpu_backend.cpp
runtime/cpu/cpu_manager.cpp
runtime/cpu/cpu_kernels.cpp
runtime/cpu/emitter.cpp
runtime/cpu/external_function.cpp
)
# LLVM binary builds are typically built without RTTI
# The built-in headers are in a version-specific directory
# This must be kept in sync with the LLVM + Clang version in use
set_source_files_properties(codegen/compiler.cpp PROPERTIES COMPILE_FLAGS "-fno-rtti")
set_source_files_properties(codegen/compiler.cpp PROPERTIES COMPILE_DEFINITIONS
"EIGEN_HEADERS_PATH=\"${EIGEN_INCLUDE_DIR}\";CLANG_BUILTIN_HEADERS_PATH=\"${LLVM_LIB_DIR}/clang/5.0.0/include\";NGRAPH_HEADERS_PATH=\"${NGRAPH_INCLUDE_PATH}\"")
set(NGRAPH_CPU_PCH_ENABLE 0 CACHE STRING "Enable pre-compiled headers in the CPU backend")
set(NGRAPH_CPU_DEBUGINFO_ENABLE 0 CACHE STRING "Enable debuginfo in the CPU backend")
set_source_files_properties(runtime/cpu/external_function.cpp PROPERTIES COMPILE_DEFINITIONS
"NGCPU_PCH=${NGRAPH_CPU_PCH_ENABLE};NGCPU_DEBUGINFO=${NGRAPH_CPU_DEBUGINFO_ENABLE}")
endif()
add_library(ngraph SHARED ${SRC})
# Colon-separated string of runtime plugins to load. This is made explicit so that, if a
# plugin is specified at compile time but the corresponding library cannot be resolved at
# run time, an error is generated.
# E.g. when compiling with Argon and Xpu support, pass -DRUNTIME_PLUGIN_LIBS="libargon.so:libxpu.so".
if (DEFINED RUNTIME_PLUGIN_LIBS)
target_compile_definitions(ngraph PRIVATE RUNTIME_PLUGIN_LIBS=${RUNTIME_PLUGIN_LIBS})
else()
target_compile_definitions(ngraph PRIVATE RUNTIME_PLUGIN_LIBS="")
endif()
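A minimal sketch of how a colon-separated RUNTIME_PLUGIN_LIBS value might be consumed at startup; the loader below is an illustrative assumption, not part of this change, and presumes the macro expands to a string literal and that POSIX dlopen is available:
#include <dlfcn.h>
#include <sstream>
#include <stdexcept>
#include <string>
// Hypothetical loader: split the compile-time RUNTIME_PLUGIN_LIBS string on ':' and
// dlopen each entry, generating an error if a plugin cannot be resolved at run time.
static void load_runtime_plugins()
{
    std::istringstream libs(RUNTIME_PLUGIN_LIBS);
    std::string lib;
    while (std::getline(libs, lib, ':'))
    {
        if (!lib.empty() && nullptr == dlopen(lib.c_str(), RTLD_NOW | RTLD_GLOBAL))
        {
            throw std::runtime_error("Could not load runtime plugin: " + lib);
        }
    }
}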
# This is used to ensure that libngraph.so and libargon.so are in the same directory for dlopen.
# Effective at build time only; this does not affect `make install` logic.
if (DEFINED COMMON_LIBRARY_OUTPUT_DIRECTORY)
set_target_properties(ngraph PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${COMMON_LIBRARY_OUTPUT_DIRECTORY})
else()
set_target_properties(ngraph PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
message(STATUS "LIBRARY_OUTPUT_DIRECTORY set to: ${COMMON_LIBRARY_OUTPUT_DIRECTORY}")
target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}")
if(NGRAPH_CPU_ENABLE AND LLVM_LINK_LIBS)
target_link_libraries(ngraph LINK_PRIVATE ${LLVM_LINK_LIBS})
endif()
if (APPLE)
set_property(TARGET ngraph PROPERTY PREFIX "lib")
set_property(TARGET ngraph PROPERTY OUTPUT_NAME "ngraph.so")
set_property(TARGET ngraph PROPERTY SUFFIX "")
else()
include_directories("${MKLDNN_INCLUDE_DIR}")
endif()
if(NGRAPH_CPU_ENABLE AND MKLDNN_LIB_DIR)
target_link_libraries(ngraph LINK_PRIVATE mkldnn)
endif()
#-----------------------------------------------------------------------------------------------
# Installation logic...
#-----------------------------------------------------------------------------------------------
......@@ -146,3 +201,11 @@ install(DIRECTORY
endif()
add_dependencies(ngraph eigen)
if(NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR)
add_dependencies(ngraph ext_llvm)
endif()
if(NGRAPH_CPU_ENABLE AND MKLDNN_INCLUDE_DIR)
add_dependencies(ngraph ext_mkldnn)
endif()
......@@ -30,9 +30,6 @@
using namespace ngraph;
/// @brief Make a zero matching a value type.
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type);
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& tensor_view_type)
{
std::shared_ptr<Node> zero =
......@@ -50,34 +47,6 @@ std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& ten
return zero;
}
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TupleType>& tuple_type)
{
std::vector<std::shared_ptr<Node>> elements;
for (auto& value_type : tuple_type->get_element_types())
{
elements.push_back(make_zero(value_type));
}
return std::make_shared<op::Tuple>(elements);
}
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type)
{
std::shared_ptr<const TensorViewType> tensor_view_type =
std::dynamic_pointer_cast<const TensorViewType>(value_type);
if (nullptr != tensor_view_type)
{
return (make_zero(tensor_view_type));
}
std::shared_ptr<const TupleType> tuple_type =
std::dynamic_pointer_cast<const TupleType>(value_type);
if (nullptr != tuple_type)
{
return make_zero(tuple_type);
}
// Should be impossible
throw ngraph_error("Unknown value type");
}
autodiff::Adjoints::Adjoints(const std::shared_ptr<Node>& y, const std::shared_ptr<Node>& c)
{
// Pass 1 determines which nodes contribute to y as well as setting up a reverse
......@@ -143,7 +112,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it)
{
auto result = make_zero(x->get_value_type());
auto result = make_zero(x->get_outputs().at(0).get_tensor_view_type());
adjoint_it = m_adjoint_map.insert({x.get(), result}).first;
}
return adjoint_it->second;
......@@ -160,6 +129,6 @@ void autodiff::Adjoints::add_delta(const std::shared_ptr<Node>& x,
}
else
{
m_adjoint_map.insert({x.get(), std::make_shared<op::Add>(adjoint_it->second, delta)});
adjoint_it->second = std::make_shared<op::Add>(adjoint_it->second, delta);
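// Inserting into the adjoint map is a no-op when the key already exists, so the
// new delta must be accumulated by assignment.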
}
}
......@@ -17,8 +17,6 @@
#include <memory>
#include <unordered_map>
#include "ngraph/runtime/parameterized_tensor_view.hpp"
namespace ngraph
{
class Node;
......@@ -64,30 +62,5 @@ namespace ngraph
/// @param f is f(X_i...)
/// @returns f'(X_i..., c) where f'(x_i, ..., c)_j is backprop for X_j
std::shared_ptr<Function> backprop_function(const std::shared_ptr<Function>& f);
template <typename ET>
std::vector<std::shared_ptr<runtime::ParameterizedTensorView<ET>>> backprop_derivative(
const std::shared_ptr<runtime::Manager>& manager,
const std::shared_ptr<runtime::Backend>& backend,
const std::shared_ptr<Function>& f,
const std::vector<std::shared_ptr<runtime::ParameterizedTensorView<ET>>>& args);
extern template std::vector<
std::shared_ptr<runtime::ParameterizedTensorView<ngraph::element::Float32>>>
backprop_derivative<ngraph::element::Float32>(
const std::shared_ptr<runtime::Manager>& manager,
const std::shared_ptr<runtime::Backend>& backend,
const std::shared_ptr<Function>& f,
const std::vector<
std::shared_ptr<runtime::ParameterizedTensorView<element::Float32>>>& args);
extern template std::vector<
std::shared_ptr<runtime::ParameterizedTensorView<ngraph::element::Float64>>>
backprop_derivative<ngraph::element::Float64>(
const std::shared_ptr<runtime::Manager>& manager,
const std::shared_ptr<runtime::Backend>& backend,
const std::shared_ptr<Function>& f,
const std::vector<
std::shared_ptr<runtime::ParameterizedTensorView<element::Float64>>>& args);
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <clang/CodeGen/ObjectFilePCHContainerOperations.h>
#include <clang/Driver/DriverDiagnostic.h>
#include <clang/Driver/Options.h>
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Frontend/CompilerInvocation.h>
#include <clang/Frontend/FrontendDiagnostic.h>
#include <clang/Frontend/TextDiagnosticBuffer.h>
#include <clang/Frontend/TextDiagnosticPrinter.h>
#include <clang/Frontend/Utils.h>
#include <clang/FrontendTool/Utils.h>
#include <clang/Lex/Preprocessor.h>
#include <clang/Lex/PreprocessorOptions.h>
#include <llvm/ADT/Statistic.h>
#include <llvm/LinkAllPasses.h>
#include <llvm/Option/Arg.h>
#include <llvm/Option/ArgList.h>
#include <llvm/Option/OptTable.h>
#include <llvm/Support/ErrorHandling.h>
#include <llvm/Support/ManagedStatic.h>
#include <llvm/Support/Signals.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/Timer.h>
#include <llvm/Support/raw_ostream.h>
#include <clang/Basic/DiagnosticOptions.h>
#include <clang/Basic/TargetInfo.h>
#include <clang/CodeGen/CodeGenAction.h>
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Frontend/TextDiagnosticPrinter.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include "ngraph/codegen/compiler.hpp"
// TODO: Fix leaks
using namespace clang;
using namespace llvm;
using namespace llvm::opt;
using namespace std;
using namespace ngraph::codegen;
static std::string GetExecutablePath(const char* Argv0)
{
// This just needs to be some symbol in the binary; C++ doesn't
// allow taking the address of ::main however.
void* MainAddr = reinterpret_cast<void*>(GetExecutablePath);
return llvm::sys::fs::getMainExecutable(Argv0, MainAddr);
}
execution_state::execution_state()
: m_execution_engine{nullptr}
, precompiled_headers_enabled(false)
, debuginfo_enabled(false)
{
}
execution_state::~execution_state()
{
}
std::unique_ptr<llvm::Module> execution_state::compile(const string& source, const string& name)
{
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmPrinters();
llvm::InitializeAllAsmParsers();
// Prepare compilation arguments
vector<const char*> args;
args.push_back(name.c_str());
// Prepare DiagnosticEngine
DiagnosticOptions DiagOpts;
TextDiagnosticPrinter* textDiagPrinter = new clang::TextDiagnosticPrinter(errs(), &DiagOpts);
IntrusiveRefCntPtr<clang::DiagnosticIDs> pDiagIDs;
DiagnosticsEngine* pDiagnosticsEngine =
new DiagnosticsEngine(pDiagIDs, &DiagOpts, textDiagPrinter);
// Create and initialize CompilerInstance
std::unique_ptr<CompilerInstance> Clang(new CompilerInstance());
Clang->createDiagnostics();
// Initialize CompilerInvocation
CompilerInvocation::CreateFromArgs(
Clang->getInvocation(), &args[0], &args[0] + args.size(), *pDiagnosticsEngine);
// Infer the builtin include path if unspecified.
if (Clang->getHeaderSearchOpts().UseBuiltinIncludes &&
Clang->getHeaderSearchOpts().ResourceDir.empty())
{
void* MainAddr = reinterpret_cast<void*>(GetExecutablePath);
auto path = CompilerInvocation::GetResourcesPath(args[0], MainAddr);
Clang->getHeaderSearchOpts().ResourceDir = path;
}
auto& HSO = Clang->getInvocation().getHeaderSearchOpts();
// Add base toolchain-supplied header paths
// Ideally one would use the Linux toolchain definition in clang/lib/Driver/ToolChains.h,
// but that's a private header and isn't part of the public libclang API.
// Instead of re-implementing all of that functionality as a custom toolchain,
// just hardcode the paths relevant to frequently used build/test machines for now.
HSO.AddPath(CLANG_BUILTIN_HEADERS_PATH, clang::frontend::System, false, false);
HSO.AddPath("/usr/include/x86_64-linux-gnu", clang::frontend::System, false, false);
HSO.AddPath("/usr/include", clang::frontend::System, false, false);
// Add C++ standard library headers
// Debian-like + GCC 4.8 libstdc++
HSO.AddPath("/usr/include/x86_64-linux-gnu/c++/4.8", clang::frontend::System, false, false);
HSO.AddPath("/usr/include/c++/4.8", clang::frontend::System, false, false);
// Debian-like + GCC 5 libstdc++
HSO.AddPath("/usr/include/x86_64-linux-gnu/c++/5", clang::frontend::System, false, false);
HSO.AddPath("/usr/include/c++/5", clang::frontend::System, false, false);
HSO.AddPath(EIGEN_HEADERS_PATH, clang::frontend::System, false, false);
HSO.AddPath(NGRAPH_HEADERS_PATH, clang::frontend::System, false, false);
// Language options
// These are the C++ features needed to compile ngraph headers
// and any dependencies like Eigen
auto LO = Clang->getInvocation().getLangOpts();
LO->CPlusPlus = 1;
LO->CPlusPlus11 = 1;
LO->Bool = 1;
LO->Exceptions = 1;
LO->CXXExceptions = 1;
LO->WChar = 1;
LO->RTTI = 1;
// Enable OpenMP for Eigen
LO->OpenMP = 1;
LO->OpenMPUseTLS = 1;
// CodeGen options
auto& CGO = Clang->getInvocation().getCodeGenOpts();
CGO.OptimizationLevel = 3;
CGO.RelocationModel = "static";
CGO.ThreadModel = "posix";
CGO.FloatABI = "hard";
CGO.OmitLeafFramePointer = 1;
CGO.VectorizeLoop = 1;
CGO.VectorizeSLP = 1;
CGO.CXAAtExit = 0;
if (debuginfo_enabled)
{
CGO.setDebugInfo(codegenoptions::FullDebugInfo);
}
if (precompiled_headers_enabled)
{
// Preprocessor options
auto& PPO = Clang->getInvocation().getPreprocessorOpts();
PPO.ImplicitPCHInclude = "ngcpu.pch";
PPO.DisablePCHValidation = 1;
}
// Enable various target features
// Most of these are for Eigen
auto& TO = Clang->getInvocation().getTargetOpts();
// TODO: This needs to be configurable and selected carefully
TO.CPU = "broadwell";
TO.FeaturesAsWritten.emplace_back("+sse");
TO.FeaturesAsWritten.emplace_back("+sse2");
TO.FeaturesAsWritten.emplace_back("+sse3");
TO.FeaturesAsWritten.emplace_back("+ssse3");
TO.FeaturesAsWritten.emplace_back("+sse4.1");
TO.FeaturesAsWritten.emplace_back("+sse4.2");
TO.FeaturesAsWritten.emplace_back("+avx");
TO.FeaturesAsWritten.emplace_back("+avx2");
TO.FeaturesAsWritten.emplace_back("+fma");
// Map code filename to a memoryBuffer
StringRef source_ref(source);
unique_ptr<MemoryBuffer> buffer = MemoryBuffer::getMemBufferCopy(source_ref);
Clang->getInvocation().getPreprocessorOpts().addRemappedFile(name, buffer.get());
// Create and execute action
CodeGenAction* compilerAction = new EmitCodeGenOnlyAction();
Clang->ExecuteAction(*compilerAction);
buffer.release();
return compilerAction->takeModule();
}
bool execution_state::add_module(std::unique_ptr<llvm::Module>& module)
{
if (module)
{
if (!m_execution_engine)
{
m_execution_engine = llvm::EngineBuilder(move(module))
.setEngineKind(llvm::EngineKind::JIT)
.setOptLevel(llvm::CodeGenOpt::Aggressive)
.setErrorStr(&jit_error)
.create();
if (!m_execution_engine)
{
return false;
}
}
}
else
{
return false;
}
return true;
}
void execution_state::finalize()
{
if (m_execution_engine)
{
m_execution_engine->finalizeObject();
m_execution_engine->runStaticConstructorsDestructors(false);
}
else
{
throw std::runtime_error(
"Error in finalize: " +
(jit_error.empty() ? "Could not create an execution engine" : jit_error));
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <llvm/ExecutionEngine/MCJIT.h> // forces JIT to link in
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/Option/Arg.h>
namespace ngraph
{
namespace codegen
{
class module;
class execution_state;
}
}
class ngraph::codegen::module
{
public:
private:
std::unique_ptr<llvm::Module> m_module;
};
class ngraph::codegen::execution_state : public llvm::SectionMemoryManager
{
public:
execution_state();
~execution_state();
void set_precompiled_headers_enabled(bool state) { precompiled_headers_enabled = state; }
bool is_precompiled_headers_enabled() { return precompiled_headers_enabled; }
void set_debuginfo_enabled(bool state) { debuginfo_enabled = state; }
bool is_debuginfo_enabled() { return debuginfo_enabled; }
std::unique_ptr<llvm::Module> compile(const std::string& source, const std::string& name = "");
bool add_module(std::unique_ptr<llvm::Module>&);
void finalize();
template <typename ftype>
std::function<ftype> find_function(const std::string& func_name)
{
auto f = m_execution_engine->getPointerToNamedFunction(func_name);
return f_cast<ftype>(f);
}
private:
llvm::ExecutionEngine* m_execution_engine;
std::string jit_error;
bool precompiled_headers_enabled;
bool debuginfo_enabled;
template <typename signature>
std::function<signature> f_cast(void* f)
{
return reinterpret_cast<signature*>(f);
}
};
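A minimal usage sketch of this JIT interface, under the assumption that the compiled source exposes an extern "C" symbol (the snippet and names below are illustrative only):
// Compile a source string, hand the module to the engine, finalize, and resolve the symbol.
ngraph::codegen::execution_state estate;
auto module = estate.compile("extern \"C\" int add1(int x) { return x + 1; }", "add1.cpp");
if (estate.add_module(module))
{
    estate.finalize();
    std::function<int(int)> add1 = estate.find_function<int(int)>("add1");
    int forty_two = add1(41); // 42
}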
......@@ -14,12 +14,12 @@
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp"
using namespace ngraph;
using namespace descriptor;
Input::Input(
const std::shared_ptr<Node>& node, size_t index, size_t argno, size_t arg_index, Output& output)
Input::Input(Node* node, size_t index, size_t argno, size_t arg_index, Output& output)
: m_node(node)
, m_index(index)
, m_argno(argno)
......@@ -31,7 +31,7 @@ Input::Input(
std::shared_ptr<Node> Input::get_node()
{
return m_node.lock();
return m_node->shared_from_this();
}
const Tensor& Input::get_tensor() const
......@@ -43,3 +43,18 @@ Tensor& Input::get_tensor()
{
return m_output.get_tensor();
}
std::shared_ptr<const TensorView> Input::get_tensor_view() const
{
return m_output.get_tensor_view();
}
std::shared_ptr<TensorView> Input::get_tensor_view()
{
return m_output.get_tensor_view();
}
std::shared_ptr<const TensorViewType> Input::get_tensor_view_type() const
{
return m_output.get_tensor_view()->get_tensor_view_type();
}
......@@ -17,6 +17,7 @@
#include <memory>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/types/type.hpp"
namespace ngraph
{
......@@ -37,26 +38,41 @@ namespace ngraph
/// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input
Input(const std::shared_ptr<Node>& node,
size_t index,
size_t argno,
size_t arg_index,
Output& output);
Input(Node* node, size_t index, size_t argno, size_t arg_index, Output& output);
/// @return the node that this is an input of
std::shared_ptr<Node> get_node();
/// @return the position of the node argument that uses this input
size_t get_argno() const { return m_argno; }
/// @return the position within the node argument of this tensor
size_t get_arg_index() const { return m_arg_index; }
/// @return the position within all supplied tensors of this input
size_t get_index() const { return m_index; }
// @return the connected output
const Output& get_output() const { return m_output; }
// @return the connected output
Output& get_output() { return m_output; }
// @return the tensor of the connected output
const Tensor& get_tensor() const;
// @return the tensor of the connected output
Tensor& get_tensor();
/// @return the tensor view for the connected output
std::shared_ptr<const TensorView> get_tensor_view() const;
/// @return the tensor view for the connected output
std::shared_ptr<TensorView> get_tensor_view();
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
protected:
std::weak_ptr<Node> m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
Output& m_output;
private:
......
......@@ -16,6 +16,7 @@
#include "ngraph/except.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/types/element_type.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph::descriptor::layout;
using ngraph::Shape;
......
......@@ -15,6 +15,7 @@
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/types/element_type.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph::descriptor::layout;
......
......@@ -14,14 +14,13 @@
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/node.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::descriptor;
Output::Output(const std::shared_ptr<Node>& node,
size_t index,
const std::shared_ptr<TensorView>& tensor_view)
Output::Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view)
: m_node(node)
, m_index(index)
, m_tensor_view(tensor_view)
......@@ -36,7 +35,7 @@ void Output::add_input(Input* input)
std::shared_ptr<Node> Output::get_node() const
{
return m_node.lock();
return m_node->shared_from_this();
}
const Tensor& Output::get_tensor() const
......
......@@ -38,9 +38,7 @@ namespace ngraph
/// @param node Node that owns this output.
/// @param index Position of the output tensor in all output tensors
/// @param tensor_view The view of this tensor; where the value will be written
Output(const std::shared_ptr<Node>& node,
size_t index,
const std::shared_ptr<TensorView>& tensor_view);
Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view);
std::shared_ptr<Node> get_node() const;
size_t get_index() const { return m_index; }
......@@ -49,9 +47,14 @@ namespace ngraph
const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{
return get_tensor_view()->get_tensor_view_type();
}
protected:
std::weak_ptr<Node> m_node;
Node* m_node;
size_t m_index;
std::shared_ptr<TensorView> m_tensor_view;
std::set<Input*> m_inputs;
......
......@@ -20,38 +20,61 @@ using namespace ngraph;
atomic<size_t> Node::m_next_instance_id(0);
Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<const ValueType> value_type)
: m_arguments(arguments)
, m_value_type(value_type)
Node::Node(const std::string& node_type, const std::vector<shared_ptr<Node>>& arguments)
: m_node_type(node_type)
, m_arguments(arguments)
, m_instance_id(m_next_instance_id.fetch_add(1))
, m_is_output(false)
{
// Add this node as a user of each argument, and create an Input descriptor wired to
// each output of each argument.
for (auto node : m_arguments)
size_t i = 0;
size_t argno = 0;
for (auto arg : m_arguments)
{
node->m_users.insert(this);
arg->m_users.insert(this);
size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs())
{
m_inputs.emplace_back(this, i, argno, arg_index++, output);
i++;
}
argno++;
}
}
Node::Node()
: Node({}, nullptr)
{
}
Node::Node(std::shared_ptr<const ValueType> value_type)
: Node({}, value_type)
Node::~Node()
{
}
Node::~Node()
void Node::assert_value_type(const shared_ptr<const ValueType>& value_type) const
{
if (*m_value_type != *value_type)
{
throw ngraph_error("Setting value type to a different ValueType");
}
}
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{
if (nullptr == m_value_type)
{
m_value_type = value_type;
if (nullptr != value_type)
{
m_value_type = value_type;
vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
m_value_type->collect_tensor_views(tensor_view_types);
size_t i = 0;
for (auto tvt : tensor_view_types)
{
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(
tvt,
ngraph::descriptor::Tensor::make_tensor_name(this, i),
is_output(),
is_parameter());
m_outputs.emplace_back(this, i, tensor_view_descriptor);
i++;
}
}
}
else
{
......@@ -64,51 +87,22 @@ void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
std::shared_ptr<const ValueType> Node::get_value_type()
{
if (nullptr == m_value_type)
{
propagate_types();
}
return m_value_type;
}
const std::shared_ptr<const ValueType> Node::get_value_type() const
{
if (nullptr == m_value_type)
{
const_cast<Node*>(this)->propagate_types();
}
return m_value_type;
}
void Node::assign_tensors()
std::deque<descriptor::Output>& Node::get_outputs()
{
vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
get_value_type()->collect_tensor_views(tensor_view_types);
std::shared_ptr<Node> shared_this = shared_from_this();
size_t i = 0;
for (auto tvt : tensor_view_types)
{
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(
tvt,
ngraph::descriptor::Tensor::make_tensor_name(this, i),
is_output(),
is_parameter());
m_outputs.emplace_back(shared_this, i, tensor_view_descriptor);
i++;
}
return m_outputs;
}
i = 0;
size_t argno = 0;
for (auto arg : get_arguments())
{
size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs())
{
m_inputs.emplace_back(shared_this, i, argno, arg_index++, output);
i++;
}
argno++;
}
const std::deque<descriptor::Output>& Node::get_outputs() const
{
return m_outputs;
}
bool Node::is_parameter() const
......
......@@ -15,6 +15,7 @@
#pragma once
#include <atomic>
#include <deque>
#include <memory>
#include <set>
#include <string>
......@@ -43,10 +44,7 @@ namespace ngraph
friend class autodiff::Adjoints;
protected:
Node(const Nodes& arguments, std::shared_ptr<const ValueType> value_type = nullptr);
Node();
Node(std::shared_ptr<const ValueType> value_type);
Node(const std::string& node_type, const Nodes& arguments);
virtual ~Node();
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......@@ -56,21 +54,14 @@ namespace ngraph
public:
/// The class name, must not contain spaces
virtual std::string description() const = 0;
std::string description() const { return m_node_type; }
std::string get_name() const;
void set_name(const std::string& name);
/// Propagate types and check arguments for consistency
virtual void propagate_types() = 0;
/// Assign Input and Output vectors
// This might later need to be virtual.
void assign_tensors();
const Nodes& get_arguments() const { return m_arguments; }
void clear_arguments() { m_arguments.clear(); }
const std::multiset<Node*>& users() const { return m_users; }
virtual std::string get_node_id() const;
std::string get_node_id() const;
/// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern
......@@ -83,14 +74,10 @@ namespace ngraph
std::shared_ptr<const ValueType> get_value_type();
const std::shared_ptr<const ValueType> get_value_type() const;
void set_value_type(const element::Type& element_type, const Shape& shape)
void assert_value_type(const std::shared_ptr<const ValueType>& value_type) const;
void assert_value_type(const element::Type& element_type, const Shape& shape) const
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
void set_value_type(const std::shared_ptr<const ValueType>& value_type)
{
m_value_type = value_type;
assert_value_type(std::make_shared<TensorViewType>(element_type, shape));
}
// Set the value type if it has not already been set; otherwise, ensure that
......@@ -108,8 +95,8 @@ namespace ngraph
std::deque<descriptor::Input>& get_inputs() { return m_inputs; }
const std::deque<descriptor::Input>& get_inputs() const { return m_inputs; }
std::deque<descriptor::Output>& get_outputs() { return m_outputs; }
const std::deque<descriptor::Output>& get_outputs() const { return m_outputs; }
std::deque<descriptor::Output>& get_outputs();
const std::deque<descriptor::Output>& get_outputs() const;
std::unordered_set<descriptor::Tensor*> liveness_live_list;
std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list;
......@@ -117,9 +104,14 @@ namespace ngraph
std::shared_ptr<Node> backprop_node(const std::shared_ptr<Node>& x,
const std::shared_ptr<Node>& c);
/// Returns the shape if this node has tensor type, othetwise error.
/// Returns the shape if this node has tensor type, otherwise an ngraph-error is thrown.
const Shape& get_shape() const { return m_value_type->get_shape(); }
const element::Type& get_element_type() const { return m_value_type->get_element_type(); }
virtual std::shared_ptr<Node>
copy_with_new_args(const std::vector<std::shared_ptr<Node>>& new_args) const = 0;
protected:
std::string m_node_type;
Nodes m_arguments;
std::shared_ptr<const ValueType> m_value_type;
std::multiset<Node*> m_users;
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -46,11 +48,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Abs(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Abs", arg)
{
}
virtual std::string description() const override { return "Abs"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Abs>(new_args.at(0));
}
};
}
}
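The same copy_with_new_args override recurs in each op below. It gives graph-rewriting code a uniform way to clone a node onto replacement inputs without knowing the node's concrete type; a hypothetical caller (not part of this change) might look like:
#include "ngraph/ngraph.hpp"
// Hypothetical rewrite helper: rebuild `node` on top of substitute inputs.
// Each op's copy_with_new_args override checks that the arity matches.
std::shared_ptr<ngraph::Node>
    rebuild_on(const std::shared_ptr<ngraph::Node>& node,
               const std::vector<std::shared_ptr<ngraph::Node>>& new_args)
{
    return node->copy_with_new_args(new_args);
}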
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -46,11 +48,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Acos(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Acos", arg)
{
}
virtual std::string description() const override { return "Acos"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Acos>(new_args.at(0));
}
};
}
}
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -48,10 +50,18 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
: BinaryElementwiseArithmetic("Add", arg0, arg1)
{
}
virtual std::string description() const override { return "Add"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Add>(new_args.at(0), new_args.at(1));
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -46,11 +48,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Asin(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Asin", arg)
{
}
virtual std::string description() const override { return "Asin"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Asin>(new_args.at(0));
}
};
}
}
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -46,11 +48,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Atan(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Atan", arg)
{
}
virtual std::string description() const override { return "Atan"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Atan>(new_args.at(0));
}
};
}
}
......@@ -19,29 +19,23 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void BinaryElementwiseBuiltin::propagate_types()
op::BinaryElementwise::BinaryElementwise(
const std::string& node_type,
std::function<const element::Type&(const element::Type&, const element::Type&)>
element_type_function,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(node_type, Nodes{arg0, arg1})
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape())
{
throw ngraph_error("Arguments must have the same tensor view shape");
}
const element::Type& result_element_type = propagate_element_types(
const element::Type& result_element_type = element_type_function(
arg0_tensor_type->get_element_type(), arg1_tensor_type->get_element_type());
set_value_type_checked(
......
......@@ -16,20 +16,28 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& BinaryElementwiseArithmetic::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
{
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
op::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(
node_type,
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
if (arg0_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
if (arg0_element_type == element::Bool::element_type())
{
throw ngraph_error(
"Operands for arithmetic operators must have numeric element type");
}
return arg0_element_type;
return arg0_element_type;
},
arg0,
arg1)
{
}
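With this refactoring, each binary elementwise subclass supplies its element-type rule to the BinaryElementwise constructor as a lambda instead of overriding a virtual propagate_element_types. A hypothetical subclass, sketched only to show the pattern (the class name and header path are assumptions):
#include "ngraph/ops/op.hpp"
// Illustrative only: a binary op whose result simply reuses arg0's element type.
class HypotheticalBinaryOp : public ngraph::op::BinaryElementwise
{
public:
    HypotheticalBinaryOp(const std::shared_ptr<ngraph::Node>& arg0,
                         const std::shared_ptr<ngraph::Node>& arg1)
        : BinaryElementwise("HypotheticalBinaryOp",
                            [](const ngraph::element::Type& t0,
                               const ngraph::element::Type&) -> const ngraph::element::Type& {
                                return t0;
                            },
                            arg0,
                            arg1)
    {
    }
};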
......@@ -16,15 +16,23 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
using namespace ngraph;
const element::Type& BinaryElementwiseComparison::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
{
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
op::BinaryElementwiseComparison::BinaryElementwiseComparison(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(node_type,
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error(
"Arguments must have the same tensor view element type");
}
return element::Bool::element_type();
return element::Bool::element_type();
},
arg0,
arg1)
{
}
......@@ -16,25 +16,16 @@
#include "ngraph/ops/sum.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Broadcast::propagate_types()
op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: RequiresTensorViewArgs("Broadcast", {arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
auto arg_tensor_view_type = m_inputs.at(0).get_tensor_view_type();
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
......@@ -48,8 +39,8 @@ void Broadcast::propagate_types()
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape));
}
void ngraph::op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
......
......@@ -56,7 +56,7 @@ namespace ngraph
/// | ------- | ----------------------------------------------- |
/// | NGVM | Implemented for scalars, matrices, and vectors. |
class Broadcast : public Builtin
class Broadcast : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a conversion operation.
......@@ -67,19 +67,19 @@ namespace ngraph
/// remaining axes in shape must be the same as the shape of arg.
Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: Builtin({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
const AxisSet& broadcast_axes);
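/// Example: Broadcast(arg, Shape{2, 3}, AxisSet{0}) replicates a shape-{3} vector along
/// a new axis 0 to produce a {2, 3} result; deleting the broadcast axes from `shape`
/// must recover the shape of `arg`.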
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
}
virtual std::string description() const override { return "Broadcast"; }
virtual void propagate_types() override;
/// \return An set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
protected:
const Shape& get_broadcast_shape() const { return m_shape; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -46,11 +46,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Ceiling(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Ceiling", arg)
{
}
virtual std::string description() const override { return "Ceiling"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Ceiling>(new_args.at(0));
}
};
}
}
......@@ -17,27 +17,18 @@
#include "ngraph/ops/concatenate.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Concat::propagate_types()
op::Concat::Concat(const Nodes& args, size_t concatenation_axis)
: RequiresTensorViewArgs("Concat", args)
, m_concatenation_axis(concatenation_axis)
{
if (m_arguments.size() < 1)
{
throw ngraph_error("At least one argument required");
}
auto arg0_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg0_type)
{
throw ngraph_error("Argument to concat is missing type.");
}
auto arg0_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg0_type);
if (nullptr == arg0_tensor_view_type)
{
throw ngraph_error("Argument to concat is not a tensor view");
}
auto arg0_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg0_shape = arg0_tensor_view_type->get_shape();
if (m_concatenation_axis >= arg0_shape.size())
{
......@@ -47,20 +38,9 @@ void Concat::propagate_types()
size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis);
auto& arg0_element_type = arg0_tensor_view_type->get_element_type();
for (auto i = 1; i < m_arguments.size(); i++)
for (auto i = 1; i < get_inputs().size(); i++)
{
auto argi_type = m_arguments.at(i)->get_value_type();
if (nullptr == argi_type)
{
throw ngraph_error("Argument to concat is missing type.");
}
auto argi_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(argi_type);
if (nullptr == argi_tensor_view_type)
{
throw ngraph_error("Argument to concat is not a tensor view");
}
auto argi_tensor_view_type = get_inputs().at(i).get_tensor_view_type();
auto argi_shape = argi_tensor_view_type->get_shape();
if (argi_shape.size() != arg0_shape.size())
{
......@@ -85,7 +65,6 @@ void Concat::propagate_types()
}
}
}
vector<size_t> concatenated_shape = arg0_shape;
concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length;
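// Example: concatenating shapes {2, 3} and {2, 4} along axis 1 yields {2, 7}; the checks
// above require all arguments to agree in rank and in every non-concatenation axis.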
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -61,22 +63,21 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------------------------- |
/// | NGVM | Implemented for vectors and matrices. |
class Concat : public Builtin
class Concat : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
Concat(const Nodes& args, size_t concatenation_axis)
: Builtin(args)
, m_concatenation_axis(concatenation_axis)
Concat(const Nodes& args, size_t concatenation_axis);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
return std::make_shared<Concat>(new_args, m_concatenation_axis);
}
virtual std::string description() const override { return "Concatenate"; }
virtual void propagate_types() override;
/// \return The concatenation axis.
size_t get_concatenation_axis() const { return m_concatenation_axis; }
protected:
......
......@@ -14,24 +14,44 @@
#include "ngraph/ops/constant.hpp"
using namespace ngraph::op;
using namespace ngraph;
void ConstantBase::propagate_types()
namespace
{
template <typename ET>
void check_value_strings(const std::vector<std::string>& value_strings)
{
auto result = ET::read(value_strings);
}
}
op::Constant::Constant(const element::Type& et,
const Shape& shape,
const std::vector<std::string>& value_strings)
: ConstantBase("Constant", std::make_shared<TensorViewType>(et, shape))
, m_value_strings(value_strings)
{
check_args();
}
template <typename ET>
void check_value_strings(const std::vector<std::string>& value_strings)
/// \brief Constructs a tensor constant with the same initialization value copied across the tensor.
///
/// \param et The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param value_string A literal for initializing each tensor constant.
op::Constant::Constant(const element::Type& et, const Shape& shape, const std::string& value_string)
: ConstantBase("Constant", std::make_shared<TensorViewType>(et, shape))
, m_value_strings(ngraph::shape_size(shape), value_string)
{
auto result = ET::read(value_strings);
check_args();
}
void Constant::propagate_types()
void op::Constant::check_args()
{
// No actual type propagation is done here; however, we check the number of value strings and
// We check the number of value strings and
// also call check_value_strings just to make sure the result will be parseable at compile
// time. (It will throw an exception if not.)
auto tvt = std::dynamic_pointer_cast<const TensorViewType>(get_value_type());
auto tvt = std::dynamic_pointer_cast<const TensorViewType>(m_value_type);
if (nullptr == tvt)
{
throw ngraph_error("Constant does not have tensor view type");
......
......@@ -35,12 +35,11 @@ namespace ngraph
/// \brief Constructs a constant base-type node.
///
/// \param type The TensorViewType for the constant.
ConstantBase(const std::shared_ptr<TensorViewType>& type)
: Node({}, type)
ConstantBase(const std::string& node_type, const std::shared_ptr<TensorViewType>& type)
: Node(node_type, {})
{
set_value_type_checked(type);
}
virtual void propagate_types() override;
};
/// \brief Class for constants whose element types are known at C++ compile-time.
......@@ -82,18 +81,19 @@ namespace ngraph
/// \param value The value of the tensor constant.
ParameterizedConstant(
const Shape& shape,
typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>>& value)
: ConstantBase(std::make_shared<TensorViewType>(T::element_type(), shape))
const typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>>& value)
: ConstantBase("ParameterizedConstant",
std::make_shared<TensorViewType>(T::element_type(), shape))
, m_value(value)
{
}
virtual std::string description() const override { return "ParameterizedConstant"; }
virtual std::string get_node_id() const override
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
std::stringstream ss;
ss << description() << "_" /* << node_id() */;
return ss.str();
if (new_args.size() != 0)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<ParameterizedConstant<T>>(get_shape(), m_value);
}
/// \return The value of the tensor constant.
......@@ -103,7 +103,7 @@ namespace ngraph
}
protected:
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value;
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value;
};
/// \brief A 32-bit floating-point tensor constant.
......@@ -154,36 +154,28 @@ namespace ngraph
/// \param value_strings A list of literals for initializing the tensor constant. There must be one literal for each element of the tensor; i.e., `value_strings.size()` must equal `ngraph::shape_size(shape)`.
Constant(const element::Type& et,
const Shape& shape,
const std::vector<std::string>& value_strings)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(value_strings)
{
}
const std::vector<std::string>& value_strings);
/// \brief Constructs a tensor constant with the same initialization value copied across the tensor.
///
/// \param et The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param value_string A literal for initializing each tensor constant.
Constant(const element::Type& et, const Shape& shape, const std::string& value_string)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(ngraph::shape_size(shape), value_string)
{
}
Constant(const element::Type& et, const Shape& shape, const std::string& value_string);
virtual std::string description() const override { return "Constant"; }
virtual std::string get_node_id() const override
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
std::stringstream ss;
ss << description() << "_" /* << node_id() */;
return ss.str();
if (new_args.size() != 0)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Constant>(get_element_type(), get_shape(), m_value_strings);
}
/// \return The initialization literals for the tensor constant.
const std::vector<std::string>& get_value_strings() const { return m_value_strings; }
virtual void propagate_types() override;
protected:
void check_args();
const std::vector<std::string> m_value_strings;
};
}
......
......@@ -18,9 +18,13 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& Convert::propagate_element_types(const element::Type& arg_element_type) const
op::Convert::Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: UnaryElementwise("Convert",
[&](const ngraph::element::Type& ignored) -> const ngraph::element::Type& {
return element_type;
},
arg)
, m_element_type(element_type)
{
return m_element_type;
}
......@@ -48,22 +48,24 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class Convert : public UnaryElementwiseBuiltin
class Convert : public UnaryElementwise
{
public:
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param element_type Element type for the output tensor.
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: UnaryElementwiseBuiltin({arg})
, m_element_type(element_type)
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Convert>(new_args.at(0), m_element_type);
}
virtual const element::Type&
propagate_element_types(const element::Type& arg_element_type) const override;
virtual std::string description() const override { return "Convert"; }
const element::Type& get_convert_element_type() const { return m_element_type; }
protected:
const ngraph::element::Type& m_element_type;
};
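// --- Illustrative sketch, not part of this diff -----------------------------
// Convert keeps the input shape and swaps the element type; the type function
// is now a lambda captured at construction (see the .cpp hunk above) rather
// than a propagate_element_types override. Include and helper name assumed.
#include "ngraph/ngraph.hpp"

void convert_sketch()
{
    using namespace ngraph;
    auto p = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
    auto cvt = std::make_shared<op::Convert>(p, element::Int64::element_type());
    auto cvt2 = cvt->copy_with_new_args({p}); // exactly one argument, else ngraph_error
}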
......
......@@ -46,11 +46,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Cos(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Cos", arg)
{
}
virtual std::string description() const override { return "Cos"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Cos>(new_args.at(0));
}
};
}
}
......@@ -46,11 +46,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Cosh(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Cosh", arg)
{
}
virtual std::string description() const override { return "Cosh"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Cosh>(new_args.at(0));
}
};
}
}
......@@ -48,14 +48,20 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
: BinaryElementwiseArithmetic("Divide", arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Divide>(new_args.at(0), new_args.at(1));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
virtual std::string description() const override { return "Divide"; }
};
}
inline std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0,
......
......@@ -23,18 +23,14 @@
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Dot::propagate_types()
op::Dot::Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs("Dot", {arg0, arg1})
{
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
......@@ -108,8 +104,7 @@ ngraph::AxisVector range<ngraph::AxisVector>(size_t n)
return result;
}
void ngraph::op::Dot::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
auto y = m_arguments[1];
......
......@@ -102,21 +102,23 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ---------------------------------------------- |
/// | NGVM | Implemented for `arg1` with rank of exactly 2. |
class Dot : public Builtin
class Dot : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a dot product operation.
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Dot>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Dot"; }
virtual void propagate_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
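// --- Illustrative sketch, not part of this diff -----------------------------
// With Dot deriving from RequiresTensorViewArgs, argument checking happens in
// the constructor, so a malformed graph fails at build time instead of in a
// later propagate_types pass. Include and helper name assumed.
#include "ngraph/ngraph.hpp"

void dot_sketch()
{
    using namespace ngraph;
    auto a = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3});
    auto b = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{3, 4});
    auto d = std::make_shared<op::Dot>(a, b); // element-type mismatch would throw here
}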
......@@ -48,10 +48,17 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
: BinaryElementwiseComparison("Equal", arg0, arg1)
{
}
virtual std::string description() const override { return "Equal"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Equal>(new_args.at(0), new_args.at(1));
}
};
}
}
......@@ -46,14 +46,20 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Exp(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Exp", arg)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Exp>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
virtual std::string description() const override { return "Exp"; }
};
}
}
......@@ -46,11 +46,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Floor(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Floor", arg)
{
}
virtual std::string description() const override { return "Floor"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Floor>(new_args.at(0));
}
};
}
}
......@@ -16,9 +16,12 @@
#include "ngraph/function.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void FunctionCall::propagate_types()
op::FunctionCall::FunctionCall(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Node>>& args)
: Node("FunctionCall", args)
, m_function(function)
{
auto& function_params = m_function->get_parameters();
......
......@@ -14,8 +14,7 @@
#pragma once
#include "ngraph/ops/op.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......@@ -46,7 +45,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class FunctionCall : public Builtin
class FunctionCall : public ngraph::Node
{
public:
/// \brief Constructs a function call operation.
......@@ -54,15 +53,14 @@ namespace ngraph
/// \param function The function to be called.
/// \param args The arguments for the function call.
FunctionCall(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Node>>& args)
: Builtin(args)
, m_function(function)
const std::vector<std::shared_ptr<Node>>& args);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
return std::make_shared<FunctionCall>(m_function, new_args);
}
virtual std::string description() const override { return "FunctionCall"; }
virtual void propagate_types() override;
/// \return The function to be called.
std::shared_ptr<Function> get_function() const { return m_function; }
protected:
......
......@@ -17,15 +17,12 @@
#include "ngraph/ops/get_tuple_element.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void GetTupleElement::propagate_types()
op::GetTupleElement::GetTupleElement(const std::shared_ptr<Node>& arg, size_t n)
: Node("GetTupleElement", {arg})
, m_n{n}
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tuple_type =
dynamic_pointer_cast<const TupleType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg0_tuple_type)
......
......@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/ops/op.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......@@ -47,21 +47,23 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class GetTupleElement : public Builtin
class GetTupleElement : public ngraph::Node
{
public:
/// \brief Constructs a get-tuple-element operation.
///
/// \param arg The input tuple.
/// \param n The index of the tuple element to get.
GetTupleElement(const std::shared_ptr<Node>& arg, size_t n)
: Builtin({arg})
, m_n{n}
GetTupleElement(const std::shared_ptr<Node>& arg, size_t n);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<GetTupleElement>(new_args.at(0), m_n);
}
virtual void propagate_types() override;
virtual std::string description() const override { return "GetTupleElement"; }
/// \return The index of the tuple element to get.
size_t get_n() const { return m_n; }
protected:
......
......@@ -48,10 +48,17 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
: BinaryElementwiseComparison("Greater", arg0, arg1)
{
}
virtual std::string description() const override { return "Greater"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Greater>(new_args.at(0), new_args.at(1));
}
};
}
}
......@@ -48,10 +48,17 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
GreaterEq(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
: BinaryElementwiseComparison("GreaterEq", arg0, arg1)
{
}
virtual std::string description() const override { return "GreaterEq"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<GreaterEq>(new_args.at(0), new_args.at(1));
}
};
}
}
......@@ -48,10 +48,17 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
: BinaryElementwiseComparison("Less", arg0, arg1)
{
}
virtual std::string description() const override { return "Less"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Less>(new_args.at(0), new_args.at(1));
}
};
}
}
......@@ -48,10 +48,17 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
LessEq(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
: BinaryElementwiseComparison("LessEq", arg0, arg1)
{
}
virtual std::string description() const override { return "LessEq"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<LessEq>(new_args.at(0), new_args.at(1));
}
};
}
}
......@@ -46,14 +46,20 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Log(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Log", arg)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Log>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
virtual std::string description() const override { return "Log"; }
};
}
}
......@@ -48,10 +48,18 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
: BinaryElementwiseArithmetic("Maximum", arg0, arg1)
{
}
virtual std::string description() const override { return "Maximum"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Maximum>(new_args.at(0), new_args.at(1));
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -48,10 +48,18 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
: BinaryElementwiseArithmetic("Minimum", arg0, arg1)
{
}
virtual std::string description() const override { return "Minimum"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Minimum>(new_args.at(0), new_args.at(1));
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -48,11 +48,18 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
: BinaryElementwiseArithmetic("Multiply", arg0, arg1)
{
}
virtual std::string description() const override { return "Multiply"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Multiply>(new_args.at(0), new_args.at(1));
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -46,14 +46,20 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Negative(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Negative", arg)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Negative>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
virtual std::string description() const override { return "Negative"; }
};
}
inline std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0)
......
......@@ -48,10 +48,17 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
NotEqual(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
: BinaryElementwiseComparison("NotEqual", arg0, arg1)
{
}
virtual std::string description() const override { return "NotEqual"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<NotEqual>(new_args.at(0), new_args.at(1));
}
};
}
}
......@@ -13,9 +13,26 @@
// ----------------------------------------------------------------------------
#include <algorithm>
#include <memory>
#include <sstream>
#include "ngraph/except.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph;
using namespace std;
op::RequiresTensorViewArgs::RequiresTensorViewArgs(const std::string& node_type,
const std::vector<std::shared_ptr<Node>>& args)
: Node(node_type, args)
{
for (auto arg : args)
{
if (nullptr == std::dynamic_pointer_cast<const TensorViewType>(arg->get_value_type()))
{
throw ngraph_error("Arguments for node type \"" + node_type +
"\" must be tensor views");
}
}
}
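// --- Illustrative sketch, not part of this diff -----------------------------
// Any op deriving from RequiresTensorViewArgs inherits the tensor-view check
// above. "MyOp" is a hypothetical op shown only to illustrate the pattern.
#include "ngraph/except.hpp"
#include "ngraph/ops/op.hpp"

class MyOp : public ngraph::op::RequiresTensorViewArgs
{
public:
    MyOp(const std::shared_ptr<ngraph::Node>& arg)
        : RequiresTensorViewArgs("MyOp", {arg}) // throws unless arg is a tensor view
    {
    }
    virtual std::shared_ptr<ngraph::Node> copy_with_new_args(
        const std::vector<std::shared_ptr<ngraph::Node>>& new_args) const override
    {
        if (new_args.size() != 1)
            throw ngraph::ngraph_error("Incorrect number of new arguments");
        return std::make_shared<MyOp>(new_args.at(0));
    }
    virtual std::string description() const override { return "MyOp"; }
};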
......@@ -20,8 +20,9 @@ using namespace std;
using namespace ngraph::op;
Parameter::Parameter(const std::shared_ptr<const ValueType>& value_type)
: Node(value_type)
: Node("Parameter", {})
{
set_value_type_checked(value_type);
}
Parameter::Parameter(const ngraph::element::Type& element_type, const Shape& shape)
......@@ -29,10 +30,6 @@ Parameter::Parameter(const ngraph::element::Type& element_type, const Shape& sha
{
}
void Parameter::propagate_types()
{
}
void Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{
}
......@@ -55,15 +55,20 @@ namespace ngraph
/// \brief Constructs a parameter node.
///
/// \param value_type The type of the parameter.
Parameter(const std::shared_ptr<const ValueType>& value_type = nullptr);
Parameter(const std::shared_ptr<const ValueType>& value_type);
/// \brief Constructs a tensor view-typed parameter node.
///
/// \param element_type The element type of the parameter.
/// \param shape The shape of the parameter.
Parameter(const ngraph::element::Type& element_type, const Shape& shape);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 0)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Parameter>(get_value_type());
}
};
}
}
......@@ -48,10 +48,17 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
: BinaryElementwiseArithmetic("Power", arg0, arg1)
{
}
virtual std::string description() const override { return "Power"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Power>(new_args.at(0), new_args.at(1));
}
};
}
}
......@@ -16,37 +16,19 @@
#include "ngraph/function.hpp"
using namespace std;
using namespace ngraph::op;
void Reduce::propagate_types()
using namespace ngraph;
op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function,
const AxisSet& reduction_axes)
: RequiresTensorViewArgs("Reduce", {arg_reductee, arg_init})
, m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes)
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_reductee_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_reductee_type)
{
throw ngraph_error("Argument to reduce is missing type.");
}
auto arg_reductee_tensor_view_type =
dynamic_pointer_cast<const TensorViewType>(arg_reductee_type);
if (nullptr == arg_reductee_tensor_view_type)
{
throw ngraph_error("Argument to reduce is not a tensor view");
}
auto arg_reductee_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg_init_type = m_arguments.at(1)->get_value_type();
if (nullptr == arg_init_type)
{
throw ngraph_error("Argument for initial value is missing type.");
}
auto arg_init_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_init_type);
if (nullptr == arg_init_tensor_view_type)
{
throw ngraph_error("Argument for initial value is not a tensor view");
}
auto arg_init_tensor_view_type = get_inputs().at(1).get_tensor_view_type();
if (arg_init_tensor_view_type->get_shape().size() != 0)
{
throw ngraph_error("Argument for initial value is not a scalar");
......@@ -85,18 +67,18 @@ void Reduce::propagate_types()
throw ngraph_error("Reduction function has wrong number of parameters (should be two)");
}
if (*(f_params.at(0)->get_value_type()) != *(arg_init_type))
if (*(f_params.at(0)->get_value_type()) != *(arg_init->get_value_type()))
{
throw ngraph_error("Argument 0 of reduction function has wrong type");
}
if (*(f_params.at(1)->get_value_type()) != *(arg_init_type))
if (*(f_params.at(1)->get_value_type()) != *(arg_init->get_value_type()))
{
throw ngraph_error("Argument 1 of reduction function has wrong type");
}
auto f_result_type = m_reduction_function->get_result_type();
if (*(f_result_type) != *(arg_init_type))
if (*(f_result_type) != *(arg_init->get_value_type()))
{
throw ngraph_error("Return type from reduction function does not match expected");
}
......
......@@ -87,7 +87,7 @@ namespace ngraph
/// | ------- | ----------------------------------------------------- |
/// | NGVM | Fully implemented for scalars, vectors, and matrices. |
class Reduce : public Builtin
class Reduce : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a reduction operation.
......@@ -99,16 +99,17 @@ namespace ngraph
Reduce(const std::shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function,
const AxisSet& reduction_axes)
: Builtin({arg_reductee, arg_init})
, m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes)
const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Reduce>(
new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes);
}
virtual std::string description() const override { return "Reduce"; }
virtual void propagate_types() override;
/// \return The function to use for reduction.
std::shared_ptr<Function> get_reduction_function() const
{
......
......@@ -50,10 +50,17 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
: BinaryElementwiseArithmetic("Remainder", arg0, arg1)
{
}
virtual std::string description() const override { return "Remainder"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Remainder>(new_args.at(0), new_args.at(1));
}
};
}
}
......@@ -18,26 +18,16 @@
#include <algorithm>
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Reshape::propagate_types()
op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape)
: RequiresTensorViewArgs("Reshape", {arg})
, m_input_order(input_order)
, m_output_shape(output_shape)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to reshape is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to reshape is not a tensor view");
}
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg_shape = arg_tensor_view_type->get_shape();
auto arg_rank = arg_shape.size();
......@@ -79,8 +69,8 @@ void Reshape::propagate_types()
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_output_shape));
}
void ngraph::op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
auto x_type = x->get_value_type();
......
......@@ -59,7 +59,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | NGVM | Fully implemented for scalars, vectors, and matrices. Implemented for other shapes only when there is no reordering of the input axes, i.e. `input_order` is \f$(0,\dots,n-1)\f$. |
class Reshape : public Builtin
class Reshape : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a reshape operation.
......@@ -71,16 +71,16 @@ namespace ngraph
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
Reshape(const std::shared_ptr<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape)
: Builtin({arg})
, m_input_order(input_order)
, m_output_shape(output_shape)
const Shape& output_shape);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Reshape>(new_args.at(0), m_input_order, m_output_shape);
}
virtual std::string description() const override { return "Reshape"; }
virtual void propagate_types() override;
/// \return The order in which to iterate over input axes.
const AxisVector& get_input_order() const { return m_input_order; }
/// \return The shape of the output tensor.
......
......@@ -19,25 +19,16 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void Select::propagate_types()
op::Select::Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2)
: RequiresTensorViewArgs("Select", Nodes{arg0, arg1, arg2})
{
if (m_arguments.size() != 3)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
auto arg2_tensor_type = get_inputs().at(2).get_tensor_view_type();
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(1)->get_value_type());
auto arg2_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(2)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type || nullptr == arg2_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
if (arg0_tensor_type->get_element_type() != element::Bool::element_type())
{
throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type");
......
......@@ -41,7 +41,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class Select : public Builtin
class Select : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a selection operation.
......@@ -51,12 +51,15 @@ namespace ngraph
/// \param arg2 Node that produces the third input tensor.
Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2)
: Builtin(Nodes{arg0, arg1, arg2})
const std::shared_ptr<Node>& arg2);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 3)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2));
}
virtual std::string description() const override { return "Select"; }
virtual void propagate_types() override;
};
}
}
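// --- Illustrative sketch, not part of this diff -----------------------------
// Select's constructor-time checks: arg0 must be boolean, arg1/arg2 supply the
// per-element values. Include and helper name assumed.
#include "ngraph/ngraph.hpp"

void select_sketch()
{
    using namespace ngraph;
    auto cond = std::make_shared<op::Parameter>(element::Bool::element_type(), Shape{3});
    auto t = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
    auto f = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
    auto sel = std::make_shared<op::Select>(cond, t, f); // non-bool cond would throw
}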
......@@ -48,11 +48,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Sign(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Sign", arg)
{
}
virtual std::string description() const override { return "Sign"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sign>(new_args.at(0));
}
};
}
}
......@@ -46,11 +46,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Sin(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Sin", arg)
{
}
virtual std::string description() const override { return "Sin"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sin>(new_args.at(0));
}
};
}
}
......@@ -46,11 +46,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Sinh(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Sinh", arg)
{
}
virtual std::string description() const override { return "Sinh"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sinh>(new_args.at(0));
}
};
}
}
......@@ -15,25 +15,34 @@
#include "ngraph/ops/slice.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Slice::propagate_types()
op::Slice::Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Shape& step)
: RequiresTensorViewArgs("Slice", {arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(step)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
check_args();
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to slice is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to slice is not a tensor view");
}
op::Slice::Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds)
: RequiresTensorViewArgs("Slice", {arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(Shape(lower_bounds.size(), 1))
{
check_args();
}
void op::Slice::check_args()
{
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& arg_shape = arg_tensor_view_type->get_shape();
if (m_lower_bounds.size() != arg_shape.size())
......
......@@ -52,7 +52,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ----------------------------------------------- |
/// | NGVM | Implemented for scalars, matrices, and vectors. |
class Slice : public Builtin
class Slice : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a tensor slice operation.
......@@ -65,13 +65,7 @@ namespace ngraph
Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Shape& step)
: Builtin({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(step)
{
}
const Shape& step);
/// \brief Constructs a tensor slice operation with unit step; i.e., every element inside the bounding box will be copied to the output slice.
///
......@@ -80,17 +74,17 @@ namespace ngraph
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds)
: Builtin({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(Shape(lower_bounds.size(), 1))
const Coordinate& upper_bounds);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Slice>(
new_args.at(0), m_lower_bounds, m_upper_bounds, m_step);
}
virtual std::string description() const override { return "Slice"; }
virtual void propagate_types() override;
/// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
/// \return The exclusive upper-bound coordinates.
......@@ -98,6 +92,8 @@ namespace ngraph
/// \return The slicing step.
const Shape& get_step() const { return m_step; }
protected:
void check_args();
const Coordinate m_lower_bounds;
const Coordinate m_upper_bounds;
const Shape m_step;
......
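// --- Illustrative sketch, not part of this diff -----------------------------
// Both Slice constructors funnel into check_args(); the two-argument form
// defaults to a unit step. Include and helper name assumed.
#include "ngraph/ngraph.hpp"

void slice_sketch()
{
    using namespace ngraph;
    auto t = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 4});
    auto sl = std::make_shared<op::Slice>(t, Coordinate{0, 0}, Coordinate{2, 2});
    // equivalent to: op::Slice(t, Coordinate{0, 0}, Coordinate{2, 2}, Shape{1, 1})
}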
......@@ -48,14 +48,20 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
: BinaryElementwiseArithmetic("Subtract", arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Subtract>(new_args.at(0), new_args.at(1));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
virtual std::string description() const override { return "Subtract"; }
};
}
inline std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0,
......
......@@ -16,26 +16,13 @@
#include "ngraph/function.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Sum::propagate_types()
op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: RequiresTensorViewArgs("Sum", {arg})
, m_reduction_axes(reduction_axes)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to sum is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to sum is not a tensor view");
}
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& arg_element_type = arg_tensor_view_type->get_element_type();
if (arg_element_type == element::Bool::element_type())
{
......
......@@ -80,22 +80,23 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ----------------------------------------------------- |
/// | NGVM | Fully implemented for scalars, vectors, and matrices. |
class Sum : public Builtin
class Sum : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a summation operation.
///
/// \param arg The tensor view to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: Builtin({arg})
, m_reduction_axes(reduction_axes)
Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sum>(new_args.at(0), m_reduction_axes);
}
virtual std::string description() const override { return "Sum"; }
virtual void propagate_types() override;
/// \return The axis positions (0-based) to be eliminated through summation.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
protected:
......
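// --- Illustrative sketch, not part of this diff -----------------------------
// Sum eliminates the listed (0-based) axes; summing axis 1 of a {2,3} tensor
// yields a {2} result. Include and helper name assumed.
#include "ngraph/ngraph.hpp"

void sum_sketch()
{
    using namespace ngraph;
    auto m = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3});
    auto s = std::make_shared<op::Sum>(m, AxisSet{1}); // a boolean input would throw
}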
......@@ -46,11 +46,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Tan(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Tan", arg)
{
}
virtual std::string description() const override { return "Tan"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Tan>(new_args.at(0));
}
};
}
}
......@@ -46,11 +46,17 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Tanh(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
: UnaryElementwiseArithmetic("Tanh", arg)
{
}
virtual std::string description() const override { return "Tanh"; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Tanh>(new_args.at(0));
}
};
}
}
......@@ -18,9 +18,10 @@
#include "ngraph/ops/tuple.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Tuple::propagate_types()
op::Tuple::Tuple(const Nodes& args)
: Node("Tuple", args)
{
vector<shared_ptr<const ValueType>> element_types;
for (auto argument : m_arguments)
......
......@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/ops/op.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......@@ -39,19 +39,19 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class Tuple : public Builtin
class Tuple : public ngraph::Node
{
public:
/// \brief Constructs a tuple construction operation.
///
/// \param args The nodes that produce the elements of the constructed tuple.
Tuple(const Nodes& args)
: Builtin(args)
Tuple(const Nodes& args);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
return std::make_shared<Tuple>(new_args);
}
virtual std::string description() const override { return "Tuple"; }
virtual void propagate_types() override;
};
}
}
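// --- Illustrative sketch, not part of this diff -----------------------------
// Tuple packs values; GetTupleElement(n) projects element n back out. Include
// and helper name assumed.
#include "ngraph/ngraph.hpp"

void tuple_sketch()
{
    using namespace ngraph;
    auto x = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{2});
    auto y = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
    auto tup = std::make_shared<op::Tuple>(Nodes{x, y});
    auto second = std::make_shared<op::GetTupleElement>(tup, 1); // selects y's value
}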
......@@ -18,24 +18,16 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void UnaryElementwiseBuiltin::propagate_types()
op::UnaryElementwise::UnaryElementwise(
const std::string& node_type,
std::function<const element::Type&(const element::Type&)> element_type_function,
const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(node_type, Nodes{arg})
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg_tensor_type)
{
throw ngraph_error("Argument must be tensor view");
}
auto arg_tensor_type = get_inputs().at(0).get_tensor_view_type();
const element::Type& result_element_type =
propagate_element_types(arg_tensor_type->get_element_type());
element_type_function(arg_tensor_type->get_element_type());
set_value_type_checked(
make_shared<TensorViewType>(result_element_type, arg_tensor_type->get_shape()));
......
......@@ -15,15 +15,21 @@
#include "ngraph/ops/op.hpp"
using namespace ngraph;
using namespace ngraph::op;
const element::Type&
UnaryElementwiseArithmetic::propagate_element_types(const element::Type& arg_element_type) const
{
if (arg_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
op::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg)
: UnaryElementwise(
node_type,
[](const ngraph::element::Type& arg_element_type) -> const ngraph::element::Type& {
if (arg_element_type == element::Bool::element_type())
{
throw ngraph_error(
"Operands for arithmetic operators must have numeric element "
"type");
}
return arg_element_type;
},
arg)
{
}
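// --- Illustrative sketch, not part of this diff -----------------------------
// The lambda above folds the old propagate_element_types check into
// construction, so a boolean input now fails as soon as the op is built.
// Include and helper name assumed.
#include "ngraph/ngraph.hpp"

void unary_arithmetic_sketch()
{
    using namespace ngraph;
    auto f = std::make_shared<op::Parameter>(element::Float32::element_type(), Shape{2});
    auto n = std::make_shared<op::Negative>(f); // result keeps Float32 and Shape{2}
    auto b = std::make_shared<op::Parameter>(element::Bool::element_type(), Shape{2});
    // std::make_shared<op::Negative>(b); // would throw: numeric element type required
}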
......@@ -12,41 +12,48 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/assign_tensors.hpp"
#pragma once
#include <exception>
#include <sstream>
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/pass/pass.hpp"
using namespace std;
using namespace ngraph;
bool pass::AssignTensors::run_on_call_graph(list<std::shared_ptr<Node>>& nodes)
namespace ngraph
{
for (shared_ptr<Node> node : nodes)
namespace pass
{
try
template <typename LT>
class AssignLayout : public CallGraphPass
{
// We need to set the node's is_output state prior to calling assign_tensors
// so that the output state can be passed to the constructed tensors.
if (node == get_state().get_functions().at(0)->get_result())
public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) override
{
node->set_is_output();
for (const std::shared_ptr<Node>& node : nodes)
{
try
{
for (const descriptor::Output& output : node->get_outputs())
{
auto tv = output.get_tensor_view();
if (nullptr == tv->get_tensor_view_layout())
{
auto layout = std::make_shared<LT>(*tv);
tv->set_tensor_view_layout(layout);
}
}
}
catch (const std::exception& e)
{
std::stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw std::invalid_argument(ss.str());
}
}
return false;
}
node->assign_tensors();
}
catch (exception& e)
{
stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw invalid_argument(ss.str());
}
};
}
return false;
}
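// --- Illustrative sketch, not part of this diff -----------------------------
// How the templated AssignLayout pass would be scheduled. register_pass and
// run_passes follow the pass::Manager interface assumed from
// "ngraph/pass/manager.hpp"; treat them as assumptions, not confirmed API.
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/pass/manager.hpp"

void schedule_assign_layout(const std::shared_ptr<ngraph::Function>& f)
{
    ngraph::pass::Manager pm;
    pm.register_pass<
        ngraph::pass::AssignLayout<ngraph::descriptor::layout::DenseTensorViewLayout>>();
    pm.run_passes(f);
}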
......@@ -20,7 +20,6 @@
#include "ngraph/descriptor/output.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/util.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "call_frame.hpp"
using namespace std;
using namespace ngraph::runtime::cpu;
CallFrame::CallFrame(EntryPoint compiled_function,
size_t n_outputs,
size_t n_inputs,
const TensorViewPtrs& temps,
const std::vector<std::shared_ptr<CallFrame>>& callees)
: m_n_outputs(n_outputs)
, m_n_inputs(n_inputs)
, m_tensor_views(n_outputs + n_inputs + temps.size())
, m_compiled_function(compiled_function)
, m_callees(callees)
{
copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_outputs + m_n_inputs);
}
void CallFrame::tensor_call(
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs)
{
copy(outputs.begin(), outputs.end(), m_tensor_views.begin());
copy(inputs.begin(), inputs.end(), m_tensor_views.begin() + m_n_outputs);
// Invoke compiled computation
m_compiled_function(this, m_tensor_views, m_callees);
// Don't hold onto inputs/outputs
fill_n(m_tensor_views.begin(), m_n_outputs + m_n_inputs, nullptr);
}
void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& results)
{
// TODO: Check types of args and result
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> inputs;
for (auto argument : arguments)
{
argument->collect_tensor_views(inputs, argument);
}
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> outputs;
for (auto result : results)
{
result->collect_tensor_views(outputs, result);
}
tensor_call(inputs, outputs);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
class PrimaryTensorView;
namespace cpu
{
class CallFrame;
using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*,
ngraph::runtime::TensorViewPtrs&,
const std::vector<std::shared_ptr<CallFrame>>&)>;
// Compile and execute graphs
class CallFrame : public ngraph::runtime::CallFrame
{
public:
CallFrame(EntryPoint compiled_function,
size_t n_outputs,
size_t n_inputs,
const TensorViewPtrs& temps,
const std::vector<std::shared_ptr<CallFrame>>& callees);
/// @brief Invoke the function with values matching the signature of the function.
///
/// Tuples will be expanded into their tensor views to build the call frame.
void
operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs);
/// @brief Invoke the function with tuples pre-expanded to their underlying tensor views.
void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs);
void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; }
template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i)
{
return m_tensor_views[i]->get_parameterized_tensor_view<ET>();
}
template <typename ET>
typename ET::type* get_tensor_view_data(size_t i)
{
return &get_parameterized_tensor_view<ET>(i)->get_vector()[0];
}
protected:
size_t m_n_outputs;
size_t m_n_inputs;
TensorViewPtrs m_tensor_views;
bool m_return;
EntryPoint m_compiled_function;
std::vector<std::shared_ptr<CallFrame>> m_callees;
};
}
}
}
......@@ -12,22 +12,13 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/pass/pass.hpp"
using namespace ngraph::runtime::cpu;
namespace ngraph
std::shared_ptr<ngraph::runtime::CallFrame>
CPUBackend::make_call_frame(const std::shared_ptr<ExternalFunction>& external_function)
{
namespace pass
{
class AssignTensors;
}
return external_function->make_call_frame();
}
class ngraph::pass::AssignTensors : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) override;
private:
};
......@@ -12,29 +12,22 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#pragma once
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/runtime/backend.hpp"
using namespace std;
using namespace ngraph;
bool pass::PropagateTypes::run_on_call_graph(list<shared_ptr<Node>>& nodes)
namespace ngraph
{
for (shared_ptr<Node> node : nodes)
namespace runtime
{
try
{
node->propagate_types();
}
catch (exception& e)
namespace cpu
{
stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw invalid_argument(ss.str());
class CPUBackend : public Backend
{
public:
virtual std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(
const std::shared_ptr<ngraph::runtime::ExternalFunction>& external_function);
};
}
}
return false;
}
......@@ -12,22 +12,4 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class PropagateTypes;
}
}
class ngraph::pass::PropagateTypes : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;
private:
};
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/types/element_type.hpp"
// CBLAS types and wrappers
namespace cblas
{
enum class Layout
{
RowMajor = 101,
ColMajor = 102
};
enum class Transpose
{
None = 111,
Transpose = 112,
ConjTrans = 113
};
enum class UpperLower
{
Upper = 121,
Lower = 122
};
enum class Diag
{
NonUnit = 131,
Unit = 132
};
enum class Side
{
Left = 141,
Right = 142
};
enum class Storage
{
Packed = 151
};
enum class Ident
{
AMatrix = 161,
BMatrix = 162
};
enum class Offset
{
RowOffset = 171,
ColOffset = 172,
FixOffset = 173
};
extern "C" {
void cblas_sgemm(const Layout layout,
const Transpose TransA,
const Transpose TransB,
const ngraph::element::Int64::type M,
const ngraph::element::Int64::type N,
const ngraph::element::Int64::type K,
const ngraph::element::Float32::type alpha,
const ngraph::element::Float32::type* A,
const ngraph::element::Int64::type lda,
const ngraph::element::Float32::type* B,
const ngraph::element::Int64::type ldb,
const ngraph::element::Float32::type beta,
ngraph::element::Float32::type* C,
const ngraph::element::Int64::type ldc);
}
}
namespace mkl
{
extern "C" {
void MKL_Somatcopy(char ordering,
char trans,
size_t rows,
size_t cols,
const ngraph::element::Float32::type alpha,
const ngraph::element::Float32::type* A,
size_t lda,
ngraph::element::Float32::type* B,
size_t ldb);
}
}
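// --- Illustrative sketch, not part of this diff -----------------------------
// A row-major 2x2 product through the cblas_sgemm declaration above
// (C = alpha * A * B + beta * C).
#include "ngraph/runtime/cpu/cpu_kernels.hpp"

void sgemm_sketch()
{
    float a[4] = {1, 2, 3, 4};
    float b[4] = {1, 0, 0, 1}; // identity, so c ends up equal to a
    float c[4] = {0, 0, 0, 0};
    cblas::cblas_sgemm(cblas::Layout::RowMajor,
                       cblas::Transpose::None,
                       cblas::Transpose::None,
                       2, 2, 2,     // M, N, K
                       1.0f, a, 2,  // alpha, A, lda
                       b, 2,        // B, ldb
                       0.0f, c, 2); // beta, C, ldc
}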
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/cpu/cpu_manager.hpp"
#include "ngraph/runtime/cpu/external_function.hpp"
using namespace ngraph::runtime::cpu;
std::shared_ptr<ngraph::runtime::Backend> CPUManager::allocate_backend()
{
return std::make_shared<CPUBackend>();
}
std::shared_ptr<ngraph::runtime::ExternalFunction>
CPUManager::compile(const std::shared_ptr<ngraph::Function>& fun)
{
return std::make_shared<ExternalFunction>(fun);
}
ngraph::runtime::Manager::Factory CPUManager::factory = ngraph::runtime::Manager::register_factory(
"CPU", [](const std::string& name) -> std::shared_ptr<ngraph::runtime::Manager> {
return std::make_shared<CPUManager>();
});
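// --- Illustrative sketch, not part of this diff -----------------------------
// The "CPU" factory registered above is looked up by name, then wires
// manager -> external function -> backend -> call frame. Manager::get is
// assumed from the Manager interface; treat it as an assumption.
#include "ngraph/runtime/manager.hpp"

void run_on_cpu(const std::shared_ptr<ngraph::Function>& f,
                const std::vector<std::shared_ptr<ngraph::runtime::Value>>& args,
                const std::vector<std::shared_ptr<ngraph::runtime::Value>>& results)
{
    auto manager = ngraph::runtime::Manager::get("CPU");
    auto external = manager->compile(f);
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(external);
    (*cf)(args, results); // tuples expand to their tensor views
}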
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include "ngraph/codegen/compiler.hpp"
#include "ngraph/runtime/manager.hpp"
namespace ngraph
{
class Function;
namespace runtime
{
class ExternalFunction;
namespace cpu
{
/// @brief Transformer for the CPU backend
class CPUManager : public Manager
{
protected:
ngraph::codegen::execution_state exec_state;
public:
virtual std::shared_ptr<Backend> allocate_backend() override;
virtual std::shared_ptr<ngraph::runtime::ExternalFunction>
compile(const std::shared_ptr<ngraph::Function>& fun) override;
static Factory factory;
};
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <Eigen/Dense>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
class TensorViewInfo;
namespace cpu
{
class CallFrame;
namespace eigen
{
using DynamicStrides = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
using VectorStrides = Eigen::Stride<Eigen::Dynamic, 1>;
template <typename ET>
using DynamicArray =
Eigen::Array<typename ET::type, Eigen::Dynamic, Eigen::Dynamic>;
template <typename ET>
using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicMatrix = Eigen::
Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
template <typename ET>
using EigenMatrixBase = Eigen::Map<DynamicMatrix<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicVector = Eigen::Matrix<typename ET::type, Eigen::Dynamic, 1>;
template <typename ET>
using EigenVectorBase = Eigen::Map<DynamicVector<ET>, 0, VectorStrides>;
namespace fmt
{
/// @brief Vector format for Eigen wrappers.
class V
{
public:
V(const TensorViewInfo& tensor_view_info)
: l0(tensor_view_info
.get_layout<
ngraph::descriptor::layout::DenseTensorViewLayout>()
->get_size())
{
}
V(size_t s)
: l0(s)
{
}
public:
size_t l0;
size_t l1{1};
size_t s0{1};
size_t s1{1};
};
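/// @brief Matrix format for Eigen wrappers.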
class M
{
M(const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: M(layout->get_shape(), layout->get_strides())
{
}
public:
M(const Shape& shape, const Strides& strides)
: l0(shape.at(0))
, l1(shape.at(1))
, s0(strides.at(0))
, s1(strides.at(1))
{
}
M(const TensorViewInfo& tensor_view_info)
: M(tensor_view_info.get_layout<
ngraph::descriptor::layout::DenseTensorViewLayout>())
{
}
public:
size_t l0;
size_t l1;
size_t s0;
size_t s1;
};
}
// ET   - element type
// FMT  - array format (fmt::V for vector, fmt::M for matrix)
// BASE - Eigen Map base to wrap (array, matrix, or vector)
template <typename ET,
typename FMT,
typename BASE,
typename STRIDES = DynamicStrides>
class EigenWrapper : public BASE
{
using base = BASE;
public:
EigenWrapper(typename ET::type* t, const FMT& fmt)
: base(t, fmt.l0, fmt.l1, STRIDES(fmt.s0, fmt.s1))
{
}
EigenWrapper(
typename ET::type* t,
const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: base(t, layout->get_size(), 1, DynamicStrides(1, 1))
{
}
EigenWrapper(CallFrame* call_frame, const TensorViewInfo& tensor_view_info)
: EigenWrapper(
call_frame->get_tensor_view_data<ET>(tensor_view_info.get_index()),
FMT(tensor_view_info))
{
}
template <typename U>
EigenWrapper& operator=(const U& other)
{
this->base::operator=(other);
return *this;
}
};
template <typename ET, typename FMT = fmt::V>
using EigenArray1d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenArray2d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenMatrix = EigenWrapper<ET, FMT, EigenMatrixBase<ET>>;
template <typename ET, typename FMT = fmt::V>
using EigenVector = EigenWrapper<ET, FMT, EigenVectorBase<ET>, VectorStrides>;
}
}
}
}
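// Usage sketch (illustrative; F32 is a hypothetical stand-in for an ngraph
// element type, which these templates only require to expose a ::type alias):
//
//     struct F32 { using type = float; };
//
//     void scale(float* data, size_t n)
//     {
//         using namespace ngraph::runtime::cpu::eigen;
//         EigenVector<F32> v(data, fmt::V(n)); // zero-copy Eigen view of data
//         v = v * 2.0f;                        // writes back through the map
//     }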
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/runtime/cpu/external_function.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
#define EMITTER_DECL(E) \
E(const ngraph::Node* n, \
ExternalFunction* ef, \
FunctionMap& function_map, \
const std::vector<TensorViewInfo>& inputs, \
const std::vector<TensorViewInfo>& outputs)
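// For reference, a declaration such as EMITTER_DECL(EmitAdd) below expands to:
//
//     EmitAdd(const ngraph::Node* n,
//             ExternalFunction* ef,
//             FunctionMap& function_map,
//             const std::vector<TensorViewInfo>& inputs,
//             const std::vector<TensorViewInfo>& outputs)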
namespace ngraph
{
namespace runtime
{
namespace cpu
{
class Emitter
{
protected:
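// Accumulated source for the generated translation unit (TU); the Emit*
// methods below append to it.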
std::string TU;
public:
Emitter()
: TU("")
{
}
std::string& GetTU() { return TU; }
void EMITTER_DECL(EmitNop);
void EMITTER_DECL(EmitAdd);
void EMITTER_DECL(EmitDot);
void EMITTER_DECL(EmitMultiply);
void EMITTER_DECL(EmitGetTupleElement);
void EMITTER_DECL(EmitTuple);
void EMITTER_DECL(EmitAbs);
void EMITTER_DECL(EmitConcat);
void EMITTER_DECL(EmitDivide);
void EMITTER_DECL(EmitEqual);
void EMITTER_DECL(EmitGreater);
void EMITTER_DECL(EmitGreaterEq);
void EMITTER_DECL(EmitLess);
void EMITTER_DECL(EmitLessEq);
void EMITTER_DECL(EmitLog);
void EMITTER_DECL(EmitMaximum);
void EMITTER_DECL(EmitMinimum);
void EMITTER_DECL(EmitNegative);
void EMITTER_DECL(EmitNotEqual);
void EMITTER_DECL(EmitSelect);
void EMITTER_DECL(EmitSubtract);
void EMITTER_DECL(EmitParameterizedConstantBool);
void EMITTER_DECL(EmitParameterizedConstantFloat32);
void EMITTER_DECL(EmitParameterizedConstantInt8);
void EMITTER_DECL(EmitParameterizedConstantInt32);
void EMITTER_DECL(EmitParameterizedConstantInt64);
void EMITTER_DECL(EmitParameterizedConstantUInt8);
void EMITTER_DECL(EmitParameterizedConstantUInt32);
void EMITTER_DECL(EmitParameterizedConstantUInt64);
void EMITTER_DECL(EmitBroadcast);
void EMITTER_DECL(EmitConvert);
void EMITTER_DECL(EmitConstant);
void EMITTER_DECL(EmitReshape);
void EMITTER_DECL(EmitFunctionCall);
void EMITTER_DECL(EmitReduce);
void EMITTER_DECL(EmitSign);
void EMITTER_DECL(EmitSlice);
void EMITTER_DECL(EmitSum);
void EMITTER_DECL(EmitExp);
void EMITTER_DECL(EmitSin);
void EMITTER_DECL(EmitSinh);
void EMITTER_DECL(EmitCos);
void EMITTER_DECL(EmitCosh);
void EMITTER_DECL(EmitTan);
void EMITTER_DECL(EmitTanh);
void EMITTER_DECL(EmitAsin);
void EMITTER_DECL(EmitAcos);
void EMITTER_DECL(EmitAtan);
};
}
}
}
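// Minimal sketch of an emitter body (the real implementations are in a
// collapsed emitter.cpp diff elsewhere in this commit):
//
//     void Emitter::EMITTER_DECL(EmitNop)
//     {
//         TU += "    // nop\n";
//     }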
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/codegen/compiler.hpp"
#include "ngraph/function.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
class ExternalFunction;
class Emitter;
class CallFrame;
using FunctionMap =
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<ExternalFunction>>;
using OpFunction = std::function<void(Emitter*,
const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)>;
using OpMap = std::unordered_map<std::type_index, OpFunction>;
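// Signature of the compiled entry point stored in
// ExternalFunction::compiled_function below.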
using EntryPoint = std::function<void(
ngraph::runtime::cpu::CallFrame*,
ngraph::runtime::TensorViewPtrs&,
const std::vector<std::shared_ptr<ngraph::runtime::cpu::CallFrame>>&)>;
class ExternalFunction : public ngraph::runtime::ExternalFunction
{
public:
ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
std::vector<std::shared_ptr<CallFrame>>& get_callees() { return callees; }
protected:
void compile(FunctionMap& function_map);
size_t m_n_inputs;
size_t m_n_outputs;
ngraph::descriptor::TensorViewPtrs m_temp_views;
EntryPoint compiled_function;
std::vector<std::shared_ptr<CallFrame>> callees;
};
}
}
}
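// Intended lifecycle (a sketch; exact call sites live elsewhere, and the
// assumption that compilation happens inside make_call_frame() follows from
// the protected compile() above rather than from the collapsed .cpp):
//
//     auto external = std::make_shared<ngraph::runtime::cpu::ExternalFunction>(f);
//     auto frame = external->make_call_frame(); // expected to trigger compile()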
......@@ -12,13 +12,63 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include <dlfcn.h>
#include <iostream>
#include <sstream>
#include <string>
#include "ngraph/except.hpp"
#include "ngraph/runtime/manager.hpp"
#include "ngraph/util.hpp"
using namespace ngraph::runtime;
bool Manager::m_is_factory_map_initialized = false;
std::shared_ptr<std::vector<void*>> Manager::m_plugin_lib_handles =
    std::make_shared<std::vector<void*>>();
void Manager::load_plugins(const std::string& runtime_plugin_libs)
{
std::vector<std::string> plugin_lib_paths = ngraph::split(runtime_plugin_libs, ':', false);
for (const auto& plugin_lib_path : plugin_lib_paths)
{
    if (!plugin_lib_path.empty())
{
void* lib_handle = dlopen(plugin_lib_path.c_str(), RTLD_NOW);
if (lib_handle)
{
Manager::m_plugin_lib_handles->push_back(lib_handle);
}
else
{
throw ngraph_error("Cannot open library " + plugin_lib_path);
}
}
}
}
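// Example input (hypothetical paths): a colon-separated list of shared
// libraries, e.g.
//
//     Manager::load_plugins("/opt/ngraph/libgpu_rt.so:/opt/ngraph/libfpga_rt.so");
//
// RTLD_NOW makes dlopen resolve all symbols immediately and run each plugin's
// static initializers, so a plugin can register its Manager factory at load
// time (see register_factory below).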
// TODO: Call this once the loaded plugins are no longer needed.
void Manager::close_plugins()
{
for (auto lib_handle : *Manager::m_plugin_lib_handles)
{
dlclose(lib_handle);
}
Manager::m_plugin_lib_handles->clear();
}
Manager::FactoryMap& Manager::get_factory_map()
{
// Stores Manager Factories
static FactoryMap factory_map;
// Try to load runtime plugins
if (!Manager::m_is_factory_map_initialized)
{
Manager::load_plugins(RUNTIME_PLUGIN_LIBS);
Manager::m_is_factory_map_initialized = true;
}
return factory_map;
}
......@@ -27,7 +77,7 @@ std::shared_ptr<Manager> Manager::get(const std::string& name)
return get_factory_map().at(name)(name);
}
Manager::Factory Manager::register_factory(std::string name, Factory factory)
Manager::Factory Manager::register_factory(const std::string& name, Factory factory)
{
get_factory_map()[name] = factory;
return factory;
......
......@@ -18,6 +18,7 @@
#include <map>
#include <memory>
#include <string>
#include <vector>
namespace ngraph
{
......@@ -46,13 +47,23 @@ namespace ngraph
compile(const std::shared_ptr<ngraph::Function>& fun) = 0;
using Factory = std::function<std::shared_ptr<Manager>(const std::string&)>;
using FactoryMap = std::map<std::string, Factory>;
static FactoryMap& get_factory_map();
static std::shared_ptr<Manager> get(const std::string& name);
static Factory register_factory(std::string name, Factory factory);
static Factory register_factory(const std::string& name, Factory factory);
private:
static void load_plugins(const std::string& runtime_plugin_libs);
static void close_plugins();
static std::shared_ptr<std::vector<void*>> m_plugin_lib_handles;
static bool m_is_factory_map_initialized;
using FactoryMap = std::map<std::string, Factory>;
static FactoryMap& get_factory_map();
};
}
}
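// Plugin registration sketch (hypothetical backend name "GPU" and class
// GPUManager; this mirrors the static CPUManager::factory registration
// earlier in this commit):
//
//     static ngraph::runtime::Manager::Factory s_factory =
//         ngraph::runtime::Manager::register_factory(
//             "GPU", [](const std::string&) -> std::shared_ptr<ngraph::runtime::Manager> {
//                 return std::make_shared<GPUManager>();
//             });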