Unverified commit f5768063, authored by Robert Kimball and committed by GitHub

Merge pull request #222 from NervanaSystems/jmenon/cpu_kernels

CPU Backend: More ops and kernels
parents 69a2d4aa 792d3328
@@ -99,18 +99,21 @@ include_directories(
     "${EIGEN_INCLUDE_DIR}"
 )
-if(LLVM_INCLUDE_DIR)
+if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
+    MKLDNN_INCLUDE_DIR)
     find_package(ZLIB REQUIRED)
-    include_directories(SYSTEM ${LLVM_INCLUDE_DIR})
-    link_directories(${LLVM_LIB_DIR})
+    include_directories(SYSTEM ${LLVM_INCLUDE_DIR} ${MKLDNN_INCLUDE_DIR})
+    link_directories(${LLVM_LIB_DIR} ${MKLDNN_LIB_DIR})
     # Add sources for the CPU backend
     # and all its dependencies
     set(SRC ${SRC}
         codegen/compiler.cpp
         runtime/cpu/call_frame.cpp
-        runtime/cpu/cpu_manager.cpp
         runtime/cpu/cpu_backend.cpp
+        runtime/cpu/cpu_manager.cpp
+        runtime/cpu/cpu_kernels.cpp
         runtime/cpu/emitter.cpp
         runtime/cpu/external_function.cpp
     )
@@ -129,7 +132,7 @@ endif()
 add_library(ngraph SHARED ${SRC})
 target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}")
-if(LLVM_LINK_LIBS)
+if(NGRAPH_CPU_ENABLE AND LLVM_LINK_LIBS)
     target_link_libraries(ngraph LINK_PRIVATE ${LLVM_LINK_LIBS})
 endif()
@@ -137,8 +140,10 @@ if (APPLE)
     set_property(TARGET ngraph PROPERTY PREFIX "lib")
     set_property(TARGET ngraph PROPERTY OUTPUT_NAME "ngraph.so")
     set_property(TARGET ngraph PROPERTY SUFFIX "")
-else()
-    include_directories("${MKLDNN_INCLUDE_DIR}")
+endif()
+if(NGRAPH_CPU_ENABLE AND MKLDNN_LIB_DIR)
+    target_link_libraries(ngraph LINK_PRIVATE mkldnn)
 endif()
 #-----------------------------------------------------------------------------------------------
@@ -178,6 +183,10 @@ endif()
 add_dependencies(ngraph eigen)
-if(NOT LLVM_PACKAGED AND LLVM_INCLUDE_DIR)
+if(NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR)
     add_dependencies(ngraph ext_llvm)
 endif()
+if(NGRAPH_CPU_ENABLE AND MKLDNN_INCLUDE_DIR)
+    add_dependencies(ngraph ext_mkldnn)
+endif()
@@ -145,10 +145,19 @@ std::unique_ptr<llvm::Module> execution_state::compile(const string& source, con
     LO->OpenMP = 1;
     LO->OpenMPUseTLS = 1;
+    // CodeGen options
+    auto& CGO = Clang->getInvocation().getCodeGenOpts();
+    CGO.OptimizationLevel = 3;
+    CGO.RelocationModel = "static";
+    CGO.ThreadModel = "posix";
+    CGO.FloatABI = "hard";
+    CGO.OmitLeafFramePointer = 1;
+    CGO.VectorizeLoop = 1;
+    CGO.VectorizeSLP = 1;
+    CGO.CXAAtExit = 0;
     if (debuginfo_enabled)
     {
-        // CodeGen options
-        auto& CGO = Clang->getInvocation().getCodeGenOpts();
         CGO.setDebugInfo(codegenoptions::FullDebugInfo);
     }
@@ -163,6 +172,12 @@ std::unique_ptr<llvm::Module> execution_state::compile(const string& source, con
     // Enable various target features
     // Most of these are for Eigen
     auto& TO = Clang->getInvocation().getTargetOpts();
+    // TODO: This needs to be configurable and selected carefully
+    TO.CPU = "broadwell";
+    TO.FeaturesAsWritten.emplace_back("+sse");
+    TO.FeaturesAsWritten.emplace_back("+sse2");
+    TO.FeaturesAsWritten.emplace_back("+sse3");
+    TO.FeaturesAsWritten.emplace_back("+ssse3");
     TO.FeaturesAsWritten.emplace_back("+sse4.1");
     TO.FeaturesAsWritten.emplace_back("+sse4.2");
     TO.FeaturesAsWritten.emplace_back("+avx");
......
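Note: the TODO above flags that the target CPU and feature list are hard-coded to "broadwell". One possible direction, sketched here purely as an assumption and not part of this patch, is to ask LLVM for the host CPU and its feature set (llvm/Support/Host.h) instead:

#include <clang/Basic/TargetOptions.h>
#include <llvm/ADT/StringMap.h>
#include <llvm/Support/Host.h>

// Hypothetical helper: derive target options from the machine we are
// running on rather than hard-coding a specific microarchitecture.
static void set_host_target(clang::TargetOptions& TO)
{
    // CPU name as LLVM detects it, e.g. "broadwell" or "skylake".
    TO.CPU = llvm::sys::getHostCPUName().str();

    // Forward every feature the host reports, e.g. "+avx2" or "-avx512f".
    llvm::StringMap<bool> features;
    if (llvm::sys::getHostCPUFeatures(features))
    {
        for (auto& f : features)
        {
            TO.FeaturesAsWritten.emplace_back((f.second ? "+" : "-") + f.first().str());
        }
    }
}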
@@ -22,12 +22,13 @@ using namespace ngraph::runtime::cpu;
 CallFrame::CallFrame(EntryPoint compiled_function,
                      size_t n_outputs,
                      size_t n_inputs,
-                     const TensorViewPtrs& temps)
+                     const TensorViewPtrs& temps,
+                     const std::vector<std::shared_ptr<CallFrame>>& callees)
     : m_n_outputs(n_outputs)
     , m_n_inputs(n_inputs)
-    , m_tensor_views(n_inputs + n_outputs + temps.size())
+    , m_tensor_views(n_outputs + n_inputs + temps.size())
     , m_compiled_function(compiled_function)
+    , m_callees(callees)
 {
     copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_outputs + m_n_inputs);
 }
@@ -40,7 +41,7 @@ void CallFrame::tensor_call(
     copy(inputs.begin(), inputs.end(), m_tensor_views.begin() + m_n_outputs);
     // Invoke compiled computation
-    m_compiled_function(this, m_tensor_views);
+    m_compiled_function(this, m_tensor_views, m_callees);
     // Don't hold onto inputs/outputs
     fill_n(m_tensor_views.begin(), m_n_outputs + m_n_inputs, nullptr);
......
@@ -31,8 +31,10 @@ namespace ngraph
         namespace cpu
         {
             class CallFrame;
             using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*,
-                                                  ngraph::runtime::TensorViewPtrs&)>;
+                                                  ngraph::runtime::TensorViewPtrs&,
+                                                  const std::vector<std::shared_ptr<CallFrame>>&)>;
             // Compile and execute graphs
             class CallFrame : public ngraph::runtime::CallFrame
@@ -41,7 +43,8 @@ namespace ngraph
                 CallFrame(EntryPoint compiled_function,
                           size_t n_outputs,
                           size_t n_inputs,
-                          const TensorViewPtrs& temps);
+                          const TensorViewPtrs& temps,
+                          const std::vector<std::shared_ptr<CallFrame>>& callees);
                 /// @brief Invoke the function with values matching the signature of the function.
                 ///
@@ -73,6 +76,7 @@ namespace ngraph
                 TensorViewPtrs m_tensor_views;
                 bool m_return;
                 EntryPoint m_compiled_function;
+                std::vector<std::shared_ptr<CallFrame>> m_callees;
             };
         }
     }
......
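Note: the widened EntryPoint threads the nested call frames through to the JIT-compiled code as a third argument. A minimal sketch, assuming only the typedef above; the no-op lambda is hypothetical and just illustrates the new three-argument shape:

#include <memory>
#include <vector>
#include "ngraph/runtime/cpu/call_frame.hpp"

using ngraph::runtime::cpu::CallFrame;
using ngraph::runtime::cpu::EntryPoint;

// Hypothetical stub with the new signature; real entry points are emitted
// and JIT-compiled by the CPU backend.
EntryPoint stub = [](CallFrame* frame,
                     ngraph::runtime::TensorViewPtrs& tensor_views,
                     const std::vector<std::shared_ptr<CallFrame>>& callees) {
    (void)frame;
    (void)tensor_views;
    (void)callees;
};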
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/types/element_type.hpp"
// CBLAS types and wrappers
namespace cblas
{
enum class Layout
{
RowMajor = 101,
ColMajor = 102
};
enum class Transpose
{
None = 111,
Transpose = 112,
ConjTrans = 113
};
enum class UpperLower
{
Upper = 121,
Lower = 122
};
enum class Diag
{
NonUnit = 131,
Unit = 132
};
enum class Side
{
Left = 141,
Right = 142
};
enum class Storage
{
Packed = 151
};
enum class Ident
{
AMatrix = 161,
BMatrix = 162
};
enum class Offset
{
RowOffset = 171,
ColOffset = 172,
FixOffset = 173
};
extern "C" {
void cblas_sgemm(const Layout layout,
const Transpose TransA,
const Transpose TransB,
const ngraph::element::Int64::type M,
const ngraph::element::Int64::type N,
const ngraph::element::Int64::type K,
const ngraph::element::Float32::type alpha,
const ngraph::element::Float32::type* A,
const ngraph::element::Int64::type lda,
const ngraph::element::Float32::type* B,
const ngraph::element::Int64::type ldb,
const ngraph::element::Float32::type beta,
ngraph::element::Float32::type* C,
const ngraph::element::Int64::type ldc);
}
}
namespace mkl
{
extern "C" {
void MKL_Somatcopy(char ordering,
char trans,
size_t rows,
size_t cols,
const ngraph::element::Float32::type alpha,
const ngraph::element::Float32::type* A,
size_t lda,
ngraph::element::Float32::type* B,
size_t ldb);
}
}
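Note: the declarations above are the standard CBLAS GEMM and MKL out-of-place transpose interfaces behind scoped enums. A hedged usage sketch, not part of the patch and requiring an MKL/CBLAS implementation at link time, computing a small row-major product C = A * B through the wrapper:

#include "ngraph/runtime/cpu/cpu_kernels.hpp"

// Hypothetical example: C (2x2) = 1.0 * A (2x3) * B (3x2) + 0.0 * C, row-major.
void example_sgemm()
{
    const float A[2 * 3] = {1, 2, 3, 4, 5, 6};
    const float B[3 * 2] = {1, 0, 0, 1, 1, 1};
    float C[2 * 2] = {0, 0, 0, 0};

    cblas::cblas_sgemm(cblas::Layout::RowMajor,
                       cblas::Transpose::None,
                       cblas::Transpose::None,
                       2,    // M: rows of A and C
                       2,    // N: columns of B and C
                       3,    // K: columns of A / rows of B
                       1.0f, // alpha
                       A,
                       3,    // lda (row-major: number of columns of A)
                       B,
                       2,    // ldb
                       0.0f, // beta
                       C,
                       2);   // ldc
}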
@@ -61,6 +61,7 @@ namespace ngraph
                 void EMITTER_DECL(EmitLessEq);
                 void EMITTER_DECL(EmitLog);
                 void EMITTER_DECL(EmitMaximum);
+                void EMITTER_DECL(EmitMinimum);
                 void EMITTER_DECL(EmitNegative);
                 void EMITTER_DECL(EmitNotEqual);
                 void EMITTER_DECL(EmitSelect);
@@ -75,6 +76,23 @@ namespace ngraph
                 void EMITTER_DECL(EmitParameterizedConstantUInt64);
                 void EMITTER_DECL(EmitBroadcast);
                 void EMITTER_DECL(EmitConvert);
+                void EMITTER_DECL(EmitConstant);
+                void EMITTER_DECL(EmitReshape);
+                void EMITTER_DECL(EmitFunctionCall);
+                void EMITTER_DECL(EmitReduce);
+                void EMITTER_DECL(EmitSign);
+                void EMITTER_DECL(EmitSlice);
+                void EMITTER_DECL(EmitSum);
+                void EMITTER_DECL(EmitExp);
+                void EMITTER_DECL(EmitSin);
+                void EMITTER_DECL(EmitSinh);
+                void EMITTER_DECL(EmitCos);
+                void EMITTER_DECL(EmitCosh);
+                void EMITTER_DECL(EmitTan);
+                void EMITTER_DECL(EmitTanh);
+                void EMITTER_DECL(EmitAsin);
+                void EMITTER_DECL(EmitAcos);
+                void EMITTER_DECL(EmitAtan);
             };
         }
     }
......
@@ -27,14 +27,20 @@
 #include "ngraph/function.hpp"
 #include "ngraph/node.hpp"
 #include "ngraph/ops/abs.hpp"
+#include "ngraph/ops/acos.hpp"
 #include "ngraph/ops/add.hpp"
+#include "ngraph/ops/asin.hpp"
+#include "ngraph/ops/atan.hpp"
 #include "ngraph/ops/broadcast.hpp"
 #include "ngraph/ops/concatenate.hpp"
 #include "ngraph/ops/constant.hpp"
 #include "ngraph/ops/convert.hpp"
+#include "ngraph/ops/cos.hpp"
+#include "ngraph/ops/cosh.hpp"
 #include "ngraph/ops/divide.hpp"
 #include "ngraph/ops/dot.hpp"
 #include "ngraph/ops/equal.hpp"
+#include "ngraph/ops/exp.hpp"
 #include "ngraph/ops/function_call.hpp"
 #include "ngraph/ops/get_tuple_element.hpp"
 #include "ngraph/ops/greater.hpp"
@@ -43,12 +49,21 @@
 #include "ngraph/ops/less_eq.hpp"
 #include "ngraph/ops/log.hpp"
 #include "ngraph/ops/maximum.hpp"
+#include "ngraph/ops/minimum.hpp"
 #include "ngraph/ops/multiply.hpp"
 #include "ngraph/ops/negative.hpp"
 #include "ngraph/ops/not_equal.hpp"
 #include "ngraph/ops/reduce.hpp"
+#include "ngraph/ops/reshape.hpp"
 #include "ngraph/ops/select.hpp"
+#include "ngraph/ops/sign.hpp"
+#include "ngraph/ops/sin.hpp"
+#include "ngraph/ops/sinh.hpp"
+#include "ngraph/ops/slice.hpp"
 #include "ngraph/ops/subtract.hpp"
+#include "ngraph/ops/sum.hpp"
+#include "ngraph/ops/tan.hpp"
+#include "ngraph/ops/tanh.hpp"
 #include "ngraph/ops/tuple.hpp"
 #include "ngraph/pass/assign_layout.hpp"
 #include "ngraph/pass/assign_tensors.hpp"
@@ -84,6 +99,7 @@ static const OpMap dispatcher{
     {TI(ngraph::op::LessEq), &Emitter::EmitLessEq},
     {TI(ngraph::op::Log), &Emitter::EmitLog},
     {TI(ngraph::op::Maximum), &Emitter::EmitMaximum},
+    {TI(ngraph::op::Minimum), &Emitter::EmitMinimum},
     {TI(ngraph::op::Negative), &Emitter::EmitNegative},
     {TI(ngraph::op::NotEqual), &Emitter::EmitNotEqual},
     {TI(ngraph::op::Select), &Emitter::EmitSelect},
@@ -106,6 +122,23 @@ static const OpMap dispatcher{
      &Emitter::EmitParameterizedConstantUInt64},
     {TI(ngraph::op::Broadcast), &Emitter::EmitBroadcast},
     {TI(ngraph::op::Convert), &Emitter::EmitConvert},
+    {TI(ngraph::op::Constant), &Emitter::EmitConstant},
+    {TI(ngraph::op::Reshape), &Emitter::EmitReshape},
+    {TI(ngraph::op::FunctionCall), &Emitter::EmitFunctionCall},
+    {TI(ngraph::op::Reduce), &Emitter::EmitReduce},
+    {TI(ngraph::op::Sign), &Emitter::EmitSign},
+    {TI(ngraph::op::Slice), &Emitter::EmitSlice},
+    {TI(ngraph::op::Sum), &Emitter::EmitSum},
+    {TI(ngraph::op::Exp), &Emitter::EmitExp},
+    {TI(ngraph::op::Sin), &Emitter::EmitSin},
+    {TI(ngraph::op::Sinh), &Emitter::EmitSinh},
+    {TI(ngraph::op::Cos), &Emitter::EmitCos},
+    {TI(ngraph::op::Cosh), &Emitter::EmitCosh},
+    {TI(ngraph::op::Tan), &Emitter::EmitTan},
+    {TI(ngraph::op::Tanh), &Emitter::EmitTanh},
+    {TI(ngraph::op::Asin), &Emitter::EmitAsin},
+    {TI(ngraph::op::Acos), &Emitter::EmitAcos},
+    {TI(ngraph::op::Atan), &Emitter::EmitAtan},
 };
 #undef TI
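Note: the dispatcher maps each concrete op's std::type_index to the corresponding Emitter member function. A minimal sketch of the lookup pattern that the backend relies on, assuming it sits in the same translation unit as the dispatcher; the helper name is hypothetical, and the handler's argument list is elided because it is fixed by EMITTER_DECL:

#include <stdexcept>
#include <typeindex>
#include <typeinfo>
#include "ngraph/node.hpp"

// Hypothetical helper mirroring the lookup done in ExternalFunction::compile.
static void emit_node(Emitter& emitter, const ngraph::Node& node)
{
    // typeid on a reference to the polymorphic base yields the dynamic type.
    auto handler = dispatcher.find(std::type_index(typeid(node)));
    if (handler == dispatcher.end())
    {
        throw std::runtime_error("Unhandled op in CPU emitter");
    }
    // Invoke the member function pointer on the emitter instance:
    // (emitter.*(handler->second))(/* node, inputs, outputs, ... */);
}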
@@ -174,7 +207,9 @@ void ExternalFunction::compile(FunctionMap& function_map)
     // Now we build the TU
     Emitter emitter;
     auto& TU = emitter.GetTU();
-    TU += R"(
+    TU += R"(// Generated by the NGraph CPU backend
+#include <algorithm>
+#include <cmath>
 #include <memory>
 #include <vector>
@@ -182,17 +217,18 @@ void ExternalFunction::compile(FunctionMap& function_map)
 #include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
 #include "ngraph/runtime/cpu/call_frame.hpp"
+#include "ngraph/runtime/cpu/cpu_kernels.hpp"
 #include "ngraph/runtime/cpu/eigen_utils.hpp"
-#include "ngraph/runtime/tensor_view_info.hpp"
+#include "ngraph/runtime/utils.hpp"
-void *__dso_handle = 0;
 using namespace ngraph::element;
 using namespace ngraph::runtime;
 using namespace ngraph::runtime::cpu::eigen;
 extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
-                             ngraph::runtime::TensorViewPtrs& tensor_views)
+                             ngraph::runtime::TensorViewPtrs& tensor_views,
+                             const std::vector<std::shared_ptr<ngraph::runtime::cpu::CallFrame>>& callees)
 {
 )";
@@ -243,8 +279,10 @@ extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
     assert(llvm_module);
     estate.add_module(llvm_module);
     estate.finalize();
-    compiled_function = estate.find_function<void(
-        ngraph::runtime::cpu::CallFrame*, ngraph::runtime::TensorViewPtrs&)>("__entrypoint");
+    compiled_function =
+        estate.find_function<void(ngraph::runtime::cpu::CallFrame*,
+                                  ngraph::runtime::TensorViewPtrs&,
+                                  const std::vector<std::shared_ptr<CallFrame>>&)>("__entrypoint");
     assert(compiled_function);
     m_is_compiled = true;
@@ -322,5 +360,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame()
 #undef M
     }
     return make_shared<ngraph::runtime::cpu::CallFrame>(
-        compiled_function, m_n_outputs, m_n_inputs, temps);
+        compiled_function, m_n_outputs, m_n_inputs, temps, callees);
 }
@@ -47,8 +47,10 @@ namespace ngraph
             using OpMap = std::unordered_map<std::type_index, OpFunction>;
-            using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*,
-                                                  ngraph::runtime::TensorViewPtrs&)>;
+            using EntryPoint = std::function<void(
+                ngraph::runtime::cpu::CallFrame*,
+                ngraph::runtime::TensorViewPtrs&,
+                const std::vector<std::shared_ptr<ngraph::runtime::cpu::CallFrame>>&)>;
             class ExternalFunction : public ngraph::runtime::ExternalFunction
             {
@@ -56,7 +58,7 @@ namespace ngraph
                 ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
                                  bool release_function = true);
                 std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
+                std::vector<std::shared_ptr<CallFrame>>& get_callees() { return callees; }
             protected:
                 void compile(FunctionMap& function_map);
@@ -64,6 +66,7 @@ namespace ngraph
                 size_t m_n_outputs;
                 ngraph::descriptor::TensorViewPtrs m_temp_views;
                 EntryPoint compiled_function;
+                std::vector<std::shared_ptr<CallFrame>> callees;
             };
         }
     }
......
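Note: the new get_callees() accessor gives the emitter a place to stash call frames for nested functions so the generated __entrypoint can reach them by index through its callees argument. A hedged sketch of that flow under stated assumptions; register_callee and both ExternalFunction arguments are hypothetical illustrations, not the patch's actual emitter code:

#include <memory>
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/external_function.hpp"

// Hypothetical helper: build a frame for `callee`, append it to the caller's
// callee list, and return the index the emitted code would use, e.g. in
// "callees.at(<index>)->...".
static size_t register_callee(ngraph::runtime::cpu::ExternalFunction& caller,
                              const std::shared_ptr<ngraph::runtime::cpu::ExternalFunction>& callee)
{
    // make_call_frame() returns the generic runtime::CallFrame base pointer;
    // the callees vector stores the CPU-specific frame type.
    auto frame = std::dynamic_pointer_cast<ngraph::runtime::cpu::CallFrame>(
        callee->make_call_frame());
    caller.get_callees().push_back(frame);
    return caller.get_callees().size() - 1;
}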