Unverified Commit 66b05785 authored by Robert Kimball, committed by GitHub

Merge pull request #2378 from NervanaSystems/bob/backend_api3

New Backend API attempt #3
parents 08c4c57c 93226c4b
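The net effect of this PR: Backend::compile() now returns a runtime::Executable, and call(), call_with_validate(), validate(), and get_performance_data() move from the Backend onto that Executable; the old enable_performance_data() call is folded into compile()'s second argument. A minimal sketch of the new calling sequence, using only names that appear in this diff (run_once and the "INTERPRETER" backend choice are illustrative, not part of the change):

```cpp
// Sketch of the new API shape introduced by this PR (not part of the diff itself).
#include <memory>
#include <vector>

#include "ngraph/function.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"

using namespace ngraph;

void run_once(const std::shared_ptr<Function>& f)
{
    auto backend = runtime::Backend::create("INTERPRETER");

    // Tensor creation is unchanged.
    std::vector<std::shared_ptr<runtime::Tensor>> inputs;
    std::vector<std::shared_ptr<runtime::Tensor>> outputs;
    for (auto& param : f->get_parameters())
    {
        inputs.push_back(backend->create_tensor(param->get_element_type(), param->get_shape()));
    }
    for (size_t i = 0; i < f->get_output_size(); ++i)
    {
        outputs.push_back(
            backend->create_tensor(f->get_output_element_type(i), f->get_output_shape(i)));
    }

    // compile() now returns an Executable; its second argument replaces the old
    // Backend::enable_performance_data() call.
    std::shared_ptr<runtime::Executable> exec = backend->compile(f, true);

    // call()/call_with_validate() moved from Backend to Executable.
    exec->call_with_validate(outputs, inputs);

    // Per-op performance counters are now queried on the Executable.
    auto perf = exec->get_performance_data();
}
```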
......@@ -28,4 +28,5 @@ else:
sys.setdlopenflags(flags)
from _pyngraph.runtime import Backend
from _pyngraph.runtime import Executable
from _pyngraph.runtime import Tensor
......@@ -20,7 +20,7 @@ from typing import List, Union
import numpy as np
from ngraph.impl import Function, Node, Shape, serialize, util
from ngraph.impl.runtime import Backend, Tensor
from ngraph.impl.runtime import Backend, Executable, Tensor
from ngraph.utils.types import get_dtype, NumericData
from ngraph.exceptions import UserInputError
......@@ -93,7 +93,7 @@ class Computation(object):
value = np.array(value)
Computation._write_ndarray_to_tensor_view(value, tensor_view)
self.runtime.backend.call(self.handle, self.result_views, self.tensor_views)
self.handle.call(self.result_views, self.tensor_views)
results = []
for result_view in self.result_views:
......
......@@ -35,23 +35,7 @@ void regclass_pyngraph_runtime_Backend(py::module m)
const ngraph::element::Type&, const ngraph::Shape&)) &
ngraph::runtime::Backend::create_tensor);
backend.def("compile",
(std::shared_ptr<ngraph::Function>(ngraph::runtime::Backend::*)(
(std::shared_ptr<ngraph::runtime::Executable>(ngraph::runtime::Backend::*)(
std::shared_ptr<ngraph::Function>)) &
ngraph::runtime::Backend::compile);
backend.def("call",
(bool (ngraph::runtime::Backend::*)(
std::shared_ptr<ngraph::Function>,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>&,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>&)) &
ngraph::runtime::Backend::call);
backend.def("remove_compiled_function",
(void (ngraph::runtime::Backend::*)(std::shared_ptr<ngraph::Function>)) &
ngraph::runtime::Backend::remove_compiled_function);
backend.def("enable_performance_data",
(void (ngraph::runtime::Backend::*)(std::shared_ptr<ngraph::Function>, bool)) &
ngraph::runtime::Backend::enable_performance_data);
backend.def("get_performance_data",
(std::vector<ngraph::runtime::PerformanceCounter>(ngraph::runtime::Backend::*)(
std::shared_ptr<ngraph::Function>)) &
ngraph::runtime::Backend::get_performance_data);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "pyngraph/runtime/executable.hpp"
namespace py = pybind11;
void regclass_pyngraph_runtime_Executable(py::module m)
{
py::class_<ngraph::runtime::Executable, std::shared_ptr<ngraph::runtime::Executable>>
executable(m, "Executable");
executable.doc() = "ngraph.impl.runtime.Executable wraps ngraph::runtime::Executable";
executable.def("call",
(bool (ngraph::runtime::Executable::*)(
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>&,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>&)) &
ngraph::runtime::Executable::call);
executable.def(
"get_performance_data",
(std::vector<ngraph::runtime::PerformanceCounter>(ngraph::runtime::Executable::*)()) &
ngraph::runtime::Executable::get_performance_data);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_runtime_Executable(py::module m);
......@@ -25,4 +25,5 @@ void regmodule_pyngraph_runtime(py::module m)
m.def_submodule("runtime", "Package ngraph.impl.runtime wraps ngraph::runtime");
regclass_pyngraph_runtime_Tensor(m_runtime);
regclass_pyngraph_runtime_Backend(m_runtime);
regclass_pyngraph_runtime_Executable(m_runtime);
}
......@@ -18,6 +18,7 @@
#include <pybind11/pybind11.h>
#include "pyngraph/runtime/backend.hpp"
#include "pyngraph/runtime/executable.hpp"
#include "pyngraph/runtime/tensor.hpp"
namespace py = pybind11;
......
......@@ -228,6 +228,7 @@ sources = [
'pyngraph/ops/softmax.cpp',
'pyngraph/ops/result.cpp',
'pyngraph/runtime/backend.cpp',
'pyngraph/runtime/executable.cpp',
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp',
'pyngraph/runtime/tensor.cpp',
'pyngraph/passes/manager.cpp',
......
......@@ -47,26 +47,12 @@ namespace ngraph
}
const std::string& get_type() const { return m_type; }
runtime::Handle compile(const std::shared_ptr<Function>& function) const
std::shared_ptr<runtime::Executable>
compile(const std::shared_ptr<Function>& function) const
{
return get().compile(function);
}
bool call(const std::shared_ptr<Function>& function,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) const
{
return get().call(function, outputs, inputs);
}
bool call_with_validate(
const std::shared_ptr<Function>& function,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) const
{
return get().call_with_validate(function, outputs, inputs);
}
private:
std::string m_type{};
mutable std::shared_ptr<runtime::Backend> m_backend{nullptr};
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory> // std::shared_ptr
#include <string> // std::string
#include <utility> // std::move
#include <vector> // std::vector
#include "ngraph/function.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/executable.hpp"
#include "ngraph/runtime/tensor.hpp"
namespace ngraph
{
namespace onnxifi
{
/// \brief ONNXIFI extensions to nGraph Executable
class Executable
{
public:
Executable(const Executable&) = delete;
Executable& operator=(const Executable&) = delete;
Executable(Executable&&) = default;
Executable& operator=(Executable&&) = default;
explicit Executable(const std::shared_ptr<runtime::Executable>& executable)
: m_executable{executable}
{
}
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) const
{
return m_executable->call(outputs, inputs);
}
bool call_with_validate(
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) const
{
return m_executable->call_with_validate(outputs, inputs);
}
private:
mutable std::shared_ptr<runtime::Executable> m_executable{nullptr};
};
} // namespace onnxifi
} // namespace ngraph
......@@ -39,78 +39,127 @@ vector<string> runtime::Backend::get_registered_devices()
return BackendManager::get_registered_backends();
}
void runtime::Backend::remove_compiled_function(shared_ptr<Function> func)
bool runtime::Backend::is_supported(const Node& node) const
{
// The default behavior is that a backend does not support any ops. If this is not the case
// then override this method and enhance.
return false;
}
vector<ngraph::runtime::PerformanceCounter>
runtime::Backend::get_performance_data(shared_ptr<Function> func) const
runtime::Executable::Executable()
{
return vector<PerformanceCounter>();
}
void runtime::Backend::validate(shared_ptr<const Function> function,
const vector<shared_ptr<runtime::Tensor>>& outputs,
runtime::Executable::~Executable()
{
}
bool runtime::Executable::call_with_validate(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
const ParameterVector& input_parameters = function->get_parameters();
if (input_parameters.size() != inputs.size())
validate(outputs, inputs);
return call(outputs, inputs);
}
void runtime::Executable::validate(const vector<std::shared_ptr<runtime::Tensor>>& outputs,
const vector<std::shared_ptr<runtime::Tensor>>& inputs)
{
const ParameterVector& parameters = get_parameters();
const ResultVector& results = get_results();
if (parameters.size() != inputs.size())
{
stringstream ss;
ss << "Call input count " << inputs.size() << " does not match Function's Parameter count "
<< input_parameters.size();
<< parameters.size();
throw runtime_error(ss.str());
}
if (function->get_output_size() != outputs.size())
if (results.size() != outputs.size())
{
stringstream ss;
ss << "Call output count " << outputs.size() << " does not match Function's Result count "
<< function->get_output_size();
<< results.size();
throw runtime_error(ss.str());
}
for (size_t i = 0; i < input_parameters.size(); i++)
for (size_t i = 0; i < parameters.size(); i++)
{
if (input_parameters[i]->get_element_type() != inputs[i]->get_element_type())
if (parameters[i]->get_element_type() != inputs[i]->get_element_type())
{
stringstream ss;
ss << "Input " << i << " type '" << inputs[i]->get_element_type()
<< "' does not match Parameter type '" << input_parameters[i]->get_element_type()
<< "'";
<< "' does not match Parameter type '" << parameters[i]->get_element_type() << "'";
throw runtime_error(ss.str());
}
if (input_parameters[i]->get_shape() != inputs[i]->get_shape())
if (parameters[i]->get_shape() != inputs[i]->get_shape())
{
stringstream ss;
ss << "Input " << i << " shape {" << join(inputs[i]->get_shape())
<< "} does not match Parameter shape {" << join(input_parameters[i]->get_shape())
<< "}";
<< "} does not match Parameter shape {" << join(parameters[i]->get_shape()) << "}";
throw runtime_error(ss.str());
}
}
for (size_t i = 0; i < function->get_output_size(); i++)
for (size_t i = 0; i < results.size(); i++)
{
if (function->get_output_element_type(i) != outputs[i]->get_element_type())
if (results[i]->get_element_type() != outputs[i]->get_element_type())
{
stringstream ss;
ss << "Output " << i << " type '" << outputs[i]->get_element_type()
<< "' does not match Result type '" << function->get_output_element_type(i) << "'";
<< "' does not match Result type '" << results[i]->get_element_type() << "'";
throw runtime_error(ss.str());
}
if (function->get_output_shape(i) != outputs[i]->get_shape())
if (results[i]->get_shape() != outputs[i]->get_shape())
{
stringstream ss;
ss << "Output " << i << " shape {" << join(outputs[i]->get_shape())
<< "} does not match Result shape {" << join(function->get_output_shape(i)) << "}";
<< "} does not match Result shape {" << join(results[i]->get_shape()) << "}";
throw runtime_error(ss.str());
}
}
}
bool runtime::Backend::is_supported(const Node& node) const
const ngraph::ParameterVector& runtime::Executable::get_parameters() const
{
return m_parameters;
}
const ngraph::ResultVector& runtime::Executable::get_results() const
{
return m_results;
}
void runtime::Executable::set_parameters_and_results(const Function& func)
{
m_parameters = func.get_parameters();
m_results = func.get_results();
}
vector<runtime::PerformanceCounter> runtime::Executable::get_performance_data() const
{
return vector<PerformanceCounter>();
}
bool runtime::Backend::is_supported_property(const Property prop) const
{
// The default behavior is that a backend does not support any extra properties. If this is
// not the case then override this method and enhance.
return false;
}
void runtime::Backend::remove_compiled_function(std::shared_ptr<Executable> exec)
{
}
bool runtime::Backend::call_with_validate(
std::shared_ptr<Executable> exec,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs)
{
return exec->call_with_validate(outputs, inputs);
}
bool runtime::Backend::call_with_validate(
const std::unique_ptr<Executable>& exec,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs)
{
return exec->call_with_validate(outputs, inputs);
}
......@@ -30,7 +30,8 @@ namespace ngraph
class ExternalFunction;
class Tensor;
class Backend;
using Handle = std::shared_ptr<Function>;
class Executable;
using Handle = std::shared_ptr<Executable>;
}
}
......@@ -81,43 +82,8 @@ public:
/// \brief Compiles a Function.
/// \param func The function to compile
/// \returns compiled function or nullptr on failure
virtual Handle compile(std::shared_ptr<Function> func) = 0;
/// \brief Executes a single iteration of a Function. If func is not compiled the call will
/// compile it.
/// \param func The function to execute
/// \returns true if iteration is successful, false otherwise
virtual bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) = 0;
/// \brief Executes a single iteration of a Function. If func is not compiled the call will
/// compile it. Optionally validates the inputs and outputs against the function graph.
/// \param func The function to execute
/// \returns true if iteration is successful, false otherwise
bool call_with_validate(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs)
{
validate(func, outputs, inputs);
return call(func, outputs, inputs);
}
/// \brief Compiled functions may be cached. This function removes a compiled function
/// from the cache.
/// \param func The function to remove from the cache
virtual void remove_compiled_function(std::shared_ptr<Function> func);
/// \brief Enable the collection of per-op performance information on a specified Function.
/// Data collection is via the `get_performance_data` method.
/// \param func The function to collect performance data on.
/// \param enable Set to true to enable or false to disable data collection
virtual void enable_performance_data(std::shared_ptr<Function> func, bool enable) {}
/// \brief Collect performance information gathered on a Function.
/// \param func The function to get collected data.
/// \returns Vector of PerformanceCounter information.
virtual std::vector<PerformanceCounter>
get_performance_data(std::shared_ptr<Function> func) const;
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
bool enable_performance_data = false) = 0;
/// \brief Test if a backend is capable of supporting an op
/// \param node is the op to test.
......@@ -133,8 +99,64 @@ public:
/// \brief Test if a particular backend property is supported
/// \param prop is the feature to test.
/// \returns true if the property is supported, false otherwise.
virtual bool is_supported_property(const Property prop) const { return false; }
void validate(std::shared_ptr<const Function> func,
virtual bool is_supported_property(const Property prop) const;
virtual void remove_compiled_function(std::shared_ptr<Executable> exec);
/// The following methods are temporary hacks to reduce the number of changes in this PR
/// They will be removed in a follow-on PR
bool call_with_validate(std::shared_ptr<Executable> handle,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
bool call_with_validate(const std::unique_ptr<Executable>& handle,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
};
class ngraph::runtime::Executable
{
public:
Executable();
virtual ~Executable();
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
/// \returns true if iteration is successful, false otherwise
virtual bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) = 0;
/// \brief Executes a single iteration of a Function.
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
/// \returns true if iteration is successful, false otherwise
bool call_with_validate(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
/// \brief Collect performance information gathered on a Function.
/// \returns Vector of PerformanceCounter information.
virtual std::vector<PerformanceCounter> get_performance_data() const;
/// \brief Validates a Function.
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
void validate(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
/// \brief Query the input Parameters
/// \returns an ngraph::op::ParameterVector of all input parameters
const ngraph::ParameterVector& get_parameters() const;
/// \brief Query the output Results
/// \returns an ngraph::ResultVector of all results
const ngraph::ResultVector& get_results() const;
protected:
/// \brief Called at the end of compile to set the values to be returned by get_parameters
/// and get_results
/// \param func The function with Results fully resolved.
void set_parameters_and_results(const Function& func);
private:
ngraph::ParameterVector m_parameters;
ngraph::ResultVector m_results;
};
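For reference, the smallest conforming Executable follows the same pattern the NOP backend uses later in this diff: the constructor calls set_parameters_and_results() so that get_parameters(), get_results(), and validate() work, and call() performs one iteration. A hedged sketch; MyExecutable is a hypothetical name, not part of this PR:

```cpp
// Hypothetical minimal Executable, mirroring the NOPExecutable pattern shown
// further down in this diff. "MyExecutable" is an illustrative name.
#include <memory>
#include <vector>

#include "ngraph/function.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"

class MyExecutable : public ngraph::runtime::Executable
{
public:
    explicit MyExecutable(const std::shared_ptr<ngraph::Function>& func)
    {
        // Record the Function's Parameters/Results so that get_parameters(),
        // get_results(), and validate() work on this Executable.
        set_parameters_and_results(*func);
    }

    bool call(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs,
              const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs) override
    {
        // A real backend would execute the compiled function here; the NOP
        // backend simply reports success.
        return true;
    }
};
```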
......@@ -68,44 +68,52 @@ shared_ptr<runtime::Tensor> runtime::cpu::CPU_Backend::create_tensor(
return make_shared<runtime::cpu::CPUTensorView>(element_type, shape, memory_pointer, this);
}
runtime::Handle runtime::cpu::CPU_Backend::compile(shared_ptr<Function> func)
shared_ptr<runtime::Executable>
runtime::cpu::CPU_Backend::compile(shared_ptr<Function> func, bool performance_counters_enabled)
{
FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function == nullptr)
shared_ptr<runtime::Executable> rc;
auto it = m_exec_map.find(func);
if (it != m_exec_map.end())
{
instance.m_external_function = make_shared<CPU_ExternalFunction>(func);
instance.m_external_function->m_emit_timing = instance.m_performance_counters_enabled;
auto cf = instance.m_external_function->make_call_frame();
instance.m_call_frame = dynamic_pointer_cast<CPU_CallFrame>(cf);
rc = it->second;
}
else
{
rc = make_shared<CPU_Executable>(func, performance_counters_enabled);
m_exec_map.insert({func, rc});
}
return func;
return rc;
}
std::shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
runtime::cpu::CPU_Backend::get_call_frame(std::shared_ptr<Function> func)
runtime::cpu::CPU_Executable::CPU_Executable(shared_ptr<Function> func,
bool performance_counters_enabled)
{
FunctionInstance& instance = m_function_map[func];
FunctionInstance& instance = m_function_instance;
if (instance.m_external_function == nullptr)
{
auto rc = compile(func);
if (!rc)
{
throw ngraph_error("couldn't compile a function");
}
instance.m_external_function = make_shared<CPU_ExternalFunction>(func);
instance.m_external_function->m_emit_timing = performance_counters_enabled;
auto cf = instance.m_external_function->make_call_frame();
instance.m_call_frame = dynamic_pointer_cast<CPU_CallFrame>(cf);
}
set_parameters_and_results(*func);
}
std::shared_ptr<ngraph::runtime::cpu::CPU_CallFrame> runtime::cpu::CPU_Executable::get_call_frame()
{
FunctionInstance& instance = m_function_instance;
return instance.m_call_frame;
}
bool runtime::cpu::CPU_Backend::call(shared_ptr<Function> func,
const vector<shared_ptr<runtime::Tensor>>& outputs,
bool runtime::cpu::CPU_Executable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
bool rc = true;
FunctionInstance& instance = m_function_map[func];
FunctionInstance& instance = m_function_instance;
if (instance.m_external_function == nullptr)
{
NGRAPH_INFO;
throw runtime_error("compile() must be called before call().");
}
......@@ -114,36 +122,28 @@ bool runtime::cpu::CPU_Backend::call(shared_ptr<Function> func,
return rc;
}
void runtime::cpu::CPU_Backend::remove_compiled_function(shared_ptr<Function> func)
{
m_function_map.erase(func);
}
void runtime::cpu::CPU_Backend::enable_performance_data(shared_ptr<Function> func, bool enable)
void runtime::cpu::CPU_Backend::remove_compiled_function(shared_ptr<Executable> exec)
{
FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function != nullptr)
for (auto it = m_exec_map.begin(); it != m_exec_map.end(); ++it)
{
throw runtime_error("Performance data collection must be enabled prior to compiling.");
if (it->second == exec)
{
m_exec_map.erase(it);
break;
}
}
instance.m_performance_counters_enabled = enable;
}
vector<runtime::PerformanceCounter>
runtime::cpu::CPU_Backend::get_performance_data(shared_ptr<Function> func) const
vector<runtime::PerformanceCounter> runtime::cpu::CPU_Executable::get_performance_data() const
{
vector<runtime::PerformanceCounter> rc;
auto it = m_function_map.find(func);
if (it != m_function_map.end())
{
const FunctionInstance& instance = it->second;
const FunctionInstance& instance = m_function_instance;
if (instance.m_external_function != nullptr)
{
rc.insert(rc.end(),
instance.m_external_function->get_perf_counters().begin(),
instance.m_external_function->get_perf_counters().end());
}
}
return rc;
}
......@@ -151,7 +151,6 @@ bool runtime::cpu::CPU_Backend::is_supported(const Node& op) const
{
return true;
}
bool runtime::cpu::CPU_Backend::is_supported_property(const Property prop) const
{
if (prop == Property::memory_attach)
......
......@@ -46,32 +46,39 @@ namespace ngraph
create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) override;
Handle compile(std::shared_ptr<Function> func) override;
std::shared_ptr<ngraph::runtime::Executable>
compile(std::shared_ptr<Function> func,
bool enable_performance_counters = false) override;
bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
void remove_compiled_function(std::shared_ptr<Function> func) override;
std::shared_ptr<CPU_CallFrame> get_call_frame(std::shared_ptr<Function> func);
void enable_performance_data(std::shared_ptr<Function> func, bool enable) override;
std::vector<PerformanceCounter>
get_performance_data(std::shared_ptr<Function> func) const override;
void remove_compiled_function(std::shared_ptr<Executable> exec) override;
bool is_supported(const Node& node) const override;
bool is_supported_property(const Property prop) const override;
private:
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<Executable>>
m_exec_map;
};
class CPU_Executable : public runtime::Executable
{
public:
CPU_Executable(std::shared_ptr<Function> func, bool performance_counters_enabled);
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
std::shared_ptr<CPU_CallFrame> get_call_frame();
std::vector<PerformanceCounter> get_performance_data() const override;
private:
class FunctionInstance
{
public:
std::shared_ptr<CPU_ExternalFunction> m_external_function;
std::shared_ptr<CPU_CallFrame> m_call_frame;
std::shared_ptr<CPU_ExternalFunction> m_external_function = nullptr;
std::shared_ptr<CPU_CallFrame> m_call_frame = nullptr;
bool m_performance_counters_enabled = false;
};
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
} m_function_instance;
};
}
}
......
......@@ -92,6 +92,7 @@ namespace ngraph
friend class CPU_Backend;
friend class CPU_CallFrame;
friend class CPU_Debugger;
friend class CPU_Executable;
public:
enum class CPUTensorRole
......
......@@ -7,6 +7,7 @@ one_hot_vector_1_far_oob
one_hot_vector_1_fp_nonint
backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
backwards_batch_norm_training
shape_of_scalar
shape_of_vector
shape_of_matrix
......
......@@ -49,7 +49,6 @@ extern "C" void delete_backend(runtime::Backend* backend)
runtime::gpu::GPU_Backend::GPU_Backend()
: runtime::Backend()
, m_context(new BackendContext())
{
}
......@@ -118,23 +117,42 @@ shared_ptr<runtime::Tensor> runtime::gpu::GPU_Backend::create_tensor(
return make_shared<runtime::gpu::GPUTensor>(element_type, shape, memory_pointer, this);
}
runtime::Handle runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func)
shared_ptr<runtime::Executable> runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func,
bool timing_enable)
{
FunctionInstance& instance = m_function_map[func];
shared_ptr<runtime::Executable> rc;
auto it = m_exec_map.find(func);
if (it != m_exec_map.end())
{
rc = it->second;
}
else
{
rc = make_shared<GPU_Executable>(func, timing_enable);
m_exec_map.insert({func, rc});
}
return rc;
}
runtime::gpu::GPU_Executable::GPU_Executable(shared_ptr<Function> func, bool enable_timing)
: m_context(new GPU_Backend::BackendContext())
{
FunctionInstance& instance = m_function_instance;
if (instance.m_compiled_function == nullptr)
{
m_context->bind_cuda_context_to_thread();
instance.m_compiled_function = runtime::gpu::GPUCompiledFunction::make(func, m_context);
instance.m_compiled_function->m_emit_timing = instance.m_performance_counters_enabled;
instance.m_compiled_function->m_emit_timing = enable_timing;
instance.m_compiled_function->compile();
instance.m_runtime = instance.m_compiled_function->m_runtime;
instance.m_inputs.resize(func->get_parameters().size());
instance.m_outputs.resize(func->get_output_size());
}
return func;
set_parameters_and_results(*func);
}
void runtime::gpu::GPU_Backend::initialize_io(void** target,
void runtime::gpu::GPU_Executable::initialize_io(void** target,
const vector<shared_ptr<runtime::Tensor>>& source)
{
for (size_t i = 0; i < source.size(); i++)
......@@ -152,11 +170,10 @@ void runtime::gpu::GPU_Backend::initialize_io(void** target,
}
}
bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
const vector<shared_ptr<runtime::Tensor>>& outputs,
bool runtime::gpu::GPU_Executable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
FunctionInstance& instance = m_function_map[func];
FunctionInstance& instance = m_function_instance;
if (instance.m_compiled_function == nullptr)
{
throw runtime_error("compile() must be called before call().");
......@@ -175,34 +192,19 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
return true;
}
void runtime::gpu::GPU_Backend::remove_compiled_function(shared_ptr<Function> func)
{
m_function_map.erase(func);
}
void runtime::gpu::GPU_Backend::enable_performance_data(shared_ptr<Function> func, bool enable)
{
FunctionInstance& instance = m_function_map[func];
if (instance.m_compiled_function != nullptr)
{
throw runtime_error("Performance data collection must be enabled prior to compiling.");
}
instance.m_performance_counters_enabled = enable;
}
// void runtime::gpu::GPU_Backend::remove_compiled_function(shared_ptr<Function> func)
// {
// m_function_map.erase(func);
// }
vector<runtime::PerformanceCounter>
runtime::gpu::GPU_Backend::get_performance_data(shared_ptr<Function> func) const
vector<runtime::PerformanceCounter> runtime::gpu::GPU_Executable::get_performance_data() const
{
std::vector<runtime::PerformanceCounter> rc;
auto it = m_function_map.find(func);
if (it != m_function_map.end())
{
const FunctionInstance& instance = it->second;
const FunctionInstance& instance = m_function_instance;
if (instance.m_compiled_function != nullptr)
{
instance.m_compiled_function->get_performance_data(rc);
}
}
return rc;
}
......
......@@ -51,16 +51,8 @@ namespace ngraph
create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) override;
Handle compile(std::shared_ptr<Function> func) override;
bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
void remove_compiled_function(std::shared_ptr<Function> func) override;
void enable_performance_data(std::shared_ptr<Function> func, bool enable) override;
std::vector<PerformanceCounter>
get_performance_data(std::shared_ptr<Function> func) const override;
std::shared_ptr<runtime::Executable> compile(std::shared_ptr<Function> func,
bool timing_enabled = false) override;
bool is_supported(const Node& node) const override;
......@@ -79,6 +71,21 @@ namespace ngraph
std::unique_ptr<CudaContextManager> m_cuda_manager;
};
private:
std::map<std::shared_ptr<Function>, std::shared_ptr<Executable>> m_exec_map;
};
class GPU_Executable : public Executable
{
public:
GPU_Executable(std::shared_ptr<Function> func, bool enable_timing);
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
// void remove_compiled_function(std::shared_ptr<Function> func) override;
std::vector<PerformanceCounter> get_performance_data() const override;
private:
class FunctionInstance
{
......@@ -88,7 +95,7 @@ namespace ngraph
EntryPoint m_runtime;
std::vector<void*> m_inputs;
std::vector<void*> m_outputs;
};
} m_function_instance;
/// \brief Convert a vector of Tensor into a vector of void* where each void*
/// points to a Tensor's data buffer.
......@@ -99,8 +106,7 @@ namespace ngraph
initialize_io(void** target,
const std::vector<std::shared_ptr<runtime::Tensor>>& source);
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
std::shared_ptr<BackendContext> m_context;
std::shared_ptr<GPU_Backend::BackendContext> m_context;
};
}
}
......
......@@ -49,6 +49,7 @@ namespace ngraph
class GPUCompiledFunction
{
friend class GPU_Backend;
friend class GPU_Executable;
public:
GPUCompiledFunction(
......
......@@ -64,14 +64,24 @@ static void node_modifiers(const Node& node, vector<string>& attributes)
}
}
runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> func)
shared_ptr<runtime::Executable>
runtime::hybrid::HybridBackend::compile(shared_ptr<Function> func,
bool enable_performance_collection)
{
if (m_function_map.find(func) == m_function_map.end())
{
// Clone function
FunctionInstance instance;
instance.m_function = clone_function(*func);
return make_shared<HybridExecutable>(
m_backend_list, func, enable_performance_collection, m_debug_enabled);
}
runtime::hybrid::HybridExecutable::HybridExecutable(
const std::vector<std::shared_ptr<runtime::Backend>>& backend_list,
const shared_ptr<Function>& func,
bool enable_performance_collection,
bool debug_enabled)
: m_function{func}
, m_backend_list{backend_list}
, m_debug_enabled{debug_enabled}
{
{
// Run placement pass
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list);
......@@ -83,16 +93,15 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
{
pass_manager.register_pass<ngraph::pass::VisualizeTree>("graph.png", node_modifiers);
}
pass_manager.run_passes(instance.m_function);
pass_manager.run_passes(m_function);
// Split function to sub_functions
tie(instance.m_sub_functions, instance.m_map_parameter_to_result) =
runtime::hybrid::split_function_by_placement(instance.m_function);
m_function_map.insert({func, instance});
tie(m_sub_functions, m_map_parameter_to_result) =
runtime::hybrid::split_function_by_placement(m_function);
// Compile subfunctions in corresponding backends
size_t subfunction_number = 0;
for (shared_ptr<Function>& sub_function : instance.m_sub_functions)
for (shared_ptr<Function>& sub_function : m_sub_functions)
{
size_t placement = sub_function->get_placement();
if (m_debug_enabled)
......@@ -104,7 +113,8 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
pm.run_passes(sub_function);
}
auto backend = m_backend_list[placement];
backend->compile(sub_function);
shared_ptr<Executable> exec = backend->compile(sub_function);
m_executable_map[sub_function] = exec;
// Compile will replace nodes so we need to make one more pass through all
// ops to reset placement
......@@ -115,38 +125,29 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
}
}
return func;
set_parameters_and_results(*func);
}
bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func,
const vector<shared_ptr<runtime::Tensor>>& outputs,
bool runtime::hybrid::HybridExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
// Get FunctionInstance
bool rc = true;
using node_map_t = unordered_map<shared_ptr<Node>, shared_ptr<runtime::Tensor>>;
auto fit = m_function_map.find(func);
if (fit == m_function_map.end())
{
throw runtime_error("compile() must be called before call().");
}
FunctionInstance& instance = fit->second;
// Parameter and result node in sub_function maps to one Tensor
node_map_t map_node_to_tensor;
for (size_t i = 0; i < inputs.size(); ++i)
{
map_node_to_tensor[instance.m_function->get_parameters()[i]] = inputs[i];
map_node_to_tensor[m_function->get_parameters()[i]] = inputs[i];
}
for (size_t i = 0; i < outputs.size(); ++i)
{
map_node_to_tensor[instance.m_function->get_results()[i]] = outputs[i];
map_node_to_tensor[m_function->get_results()[i]] = outputs[i];
}
// Call subfunctions
for (const shared_ptr<Function>& sub_function : instance.m_sub_functions)
for (const shared_ptr<Function>& sub_function : m_sub_functions)
{
// Init backend
size_t placement = sub_function->get_placement();
......@@ -174,7 +175,7 @@ bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func,
else
{
// Handle temporary tensors that go between subgraphs
auto result_node = instance.m_map_parameter_to_result.at(parameter_node);
auto result_node = m_map_parameter_to_result.at(parameter_node);
auto result = map_node_to_tensor.at(result_node);
auto parameter = backend->create_tensor(parameter_node->get_element_type(),
parameter_node->get_shape());
......@@ -215,7 +216,8 @@ bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func,
}
// Call
backend->call(sub_function, results, parameters);
auto exec = m_executable_map[sub_function];
exec->call(results, parameters);
// Need to copy any results to the correct device
for (const auto& p : copy_back)
......@@ -231,7 +233,7 @@ bool runtime::hybrid::HybridBackend::is_supported(const Node& node) const
return true;
}
size_t runtime::hybrid::HybridBackend::get_placement(const runtime::Tensor* t)
size_t runtime::hybrid::HybridExecutable::get_placement(const runtime::Tensor* t)
{
size_t index = 0;
for (const shared_ptr<ngraph::runtime::Backend>& be : m_backend_list)
......
......@@ -30,6 +30,7 @@ namespace ngraph
namespace hybrid
{
class HybridBackend;
class HybridExecutable;
}
}
}
......@@ -48,29 +49,37 @@ public:
const ngraph::Shape& shape,
void* memory_pointer) override;
Handle compile(std::shared_ptr<ngraph::Function> func) override;
bool call(std::shared_ptr<ngraph::Function> func,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs) override;
std::shared_ptr<Executable> compile(std::shared_ptr<ngraph::Function> func,
bool enable_performance_data = false) override;
bool is_supported(const ngraph::Node& node) const override;
void set_debug_enabled(bool flag) { m_debug_enabled = flag; }
private:
class FunctionInstance
{
public:
std::vector<std::shared_ptr<runtime::Backend>> m_backend_list;
bool m_debug_enabled = false;
};
class ngraph::runtime::hybrid::HybridExecutable : public runtime::Executable
{
public:
HybridExecutable(const std::vector<std::shared_ptr<runtime::Backend>>& backend_list,
const std::shared_ptr<Function>& func,
bool enable_performance_collection = false,
bool debug_enabled = false);
bool call(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs) override;
private:
std::shared_ptr<ngraph::Function> m_function;
std::vector<std::shared_ptr<ngraph::Function>> m_sub_functions;
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>,
std::shared_ptr<ngraph::op::Result>>
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>, std::shared_ptr<ngraph::op::Result>>
m_map_parameter_to_result;
};
std::map<std::shared_ptr<ngraph::Function>, FunctionInstance> m_function_map;
std::vector<std::shared_ptr<runtime::Backend>> m_backend_list;
bool m_debug_enabled = false;
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<Executable>> m_executable_map;
size_t get_placement(const runtime::Tensor* t);
};
......@@ -31,6 +31,7 @@ namespace ngraph
namespace intelgpu
{
class IntelGPUBackend;
class IntelGPUExecutable;
}
}
}
......@@ -47,33 +48,51 @@ public:
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type, const Shape& shape) override;
Handle compile(std::shared_ptr<Function> func) override;
std::shared_ptr<runtime::Executable> compile(std::shared_ptr<Function> func,
bool enable_timing = false) override;
void remove_compiled_function(std::shared_ptr<runtime::Executable> exec) override;
bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
bool is_supported_property(const Property prop) const override;
void remove_compiled_function(std::shared_ptr<Function> func) override;
void enable_performance_data(std::shared_ptr<Function> func, bool enable) override;
std::vector<PerformanceCounter>
get_performance_data(std::shared_ptr<Function> func) const override;
private:
std::shared_ptr<cldnn::engine> cldnn_engine;
std::map<std::shared_ptr<Function>, std::shared_ptr<runtime::Executable>> cldnn_networks;
bool is_supported_property(const Property prop) const override;
bool m_profile_enable = false;
long m_profile_lines_limit_count = 10;
bool m_dump_graph_enable = false;
bool m_cldnn_graph_optimize = true;
bool m_cldnn_dump_enable = false;
bool m_function_cache_disabled = false;
bool m_disable_backend_optimizations = false;
std::string m_cldnn_dump_dir = std::string("intelgpu_codegen");
};
class ngraph::runtime::intelgpu::IntelGPUExecutable : public runtime::Executable
{
public:
IntelGPUExecutable(std::shared_ptr<Function> func,
std::shared_ptr<cldnn::network> network,
bool enable_timing,
bool enable_profile,
double compilation_time,
double consumed_memory,
size_t profile_lines_limit_count);
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
std::vector<PerformanceCounter> get_performance_data() const override;
private:
class FunctionInstance
{
public:
std::shared_ptr<cldnn::network> ocl_network = nullptr;
std::shared_ptr<Function> m_function;
std::shared_ptr<cldnn::network> m_cldnn_network = nullptr;
bool m_performance_counters_enabled = false;
bool m_profile_enable = false;
double m_compilation_time = 0.0;
double m_consumed_memory = 0.0;
};
std::map<std::shared_ptr<Function>, FunctionInstance> ocl_networks;
std::shared_ptr<cldnn::engine> ocl_engine;
bool m_disable_backend_optimizations = false;
long m_profile_lines_limit_count = 10;
std::string delim = std::string(":");
// Statistic related things
void print_call_performance(const std::shared_ptr<cldnn::network> network,
......@@ -83,13 +102,4 @@ private:
double mem_compilation_consumed,
double mem_call_consumed,
double mem_current) const;
bool m_profile_enable = false;
long m_profile_lines_limit_count = 10;
bool m_dump_graph_enable = false;
bool m_cldnn_graph_optimize = true;
bool m_cldnn_dump_enable = false;
bool m_function_cache_disabled = false;
std::string m_cldnn_dump_dir = std::string("intelgpu_codegen");
std::string delim = std::string(":");
};
......@@ -64,11 +64,18 @@ shared_ptr<runtime::Tensor> runtime::interpreter::INTBackend::create_tensor(
return make_shared<runtime::HostTensor>(type, shape, memory_pointer, this);
}
runtime::Handle runtime::interpreter::INTBackend::compile(shared_ptr<Function> function)
shared_ptr<runtime::Executable>
runtime::interpreter::INTBackend::compile(shared_ptr<Function> function,
bool enable_performance_collection)
{
return make_shared<INTExecutable>(function, enable_performance_collection);
}
runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& function,
bool enable_performance_collection)
{
FunctionInstance& instance = m_function_map[function];
if (!instance.m_is_compiled)
{
FunctionInstance& instance = m_function_instance;
instance.m_is_compiled = true;
pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>();
......@@ -81,24 +88,13 @@ runtime::Handle runtime::interpreter::INTBackend::compile(shared_ptr<Function> f
instance.m_wrapped_nodes.emplace_back(node);
}
}
return function;
set_parameters_and_results(*function);
}
bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
const vector<shared_ptr<runtime::Tensor>>& outputs,
bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
auto fit = m_function_map.find(function);
if (fit == m_function_map.end())
{
throw runtime_error("compile() must be called before call().");
}
FunctionInstance& instance = fit->second;
if (!instance.m_is_compiled)
{
throw runtime_error("compile() must be called before call().");
}
FunctionInstance& instance = m_function_instance;
// convert inputs to HostTensor
vector<shared_ptr<HostTensor>> func_inputs;
......@@ -123,7 +119,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
// map function params -> HostTensor
unordered_map<descriptor::Tensor*, shared_ptr<HostTensor>> tensor_map;
size_t input_count = 0;
for (auto param : function->get_parameters())
for (auto param : get_parameters())
{
for (size_t i = 0; i < param->get_output_size(); ++i)
{
......@@ -133,9 +129,9 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
}
// map function outputs -> HostTensor
for (size_t output_count = 0; output_count < function->get_output_size(); ++output_count)
for (size_t output_count = 0; output_count < get_results().size(); ++output_count)
{
auto output = function->get_output_op(output_count);
auto output = get_results()[output_count];
if (!dynamic_pointer_cast<op::Result>(output))
{
throw ngraph_error("One of function's outputs isn't op::Result");
......@@ -229,7 +225,8 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
return true;
}
void runtime::interpreter::INTBackend::generate_calls(const element::Type& type,
void runtime::interpreter::INTExecutable::generate_calls(
const element::Type& type,
const NodeWrapper& op,
const vector<shared_ptr<HostTensor>>& outputs,
const vector<shared_ptr<HostTensor>>& inputs,
......@@ -267,24 +264,17 @@ void runtime::interpreter::INTBackend::generate_calls(const element::Type& type,
}
}
void runtime::interpreter::INTBackend::set_nan_check(shared_ptr<Function> func, bool enable)
void runtime::interpreter::INTExecutable::set_nan_check(bool enable)
{
FunctionInstance& instance = m_function_map[func];
FunctionInstance& instance = m_function_instance;
instance.m_nan_check_enabled = enable;
}
void runtime::interpreter::INTBackend::enable_performance_data(shared_ptr<Function> func,
bool enable)
{
FunctionInstance& instance = m_function_map[func];
instance.m_performance_counters_enabled = enable;
}
vector<runtime::PerformanceCounter>
runtime::interpreter::INTBackend::get_performance_data(shared_ptr<Function> func) const
runtime::interpreter::INTExecutable::get_performance_data() const
{
vector<runtime::PerformanceCounter> rc;
const FunctionInstance& instance = m_function_map.at(func);
const FunctionInstance& instance = m_function_instance;
for (const pair<const Node*, stopwatch> p : instance.m_timer_map)
{
rc.emplace_back(p.first->get_name().c_str(),
......@@ -294,7 +284,7 @@ vector<runtime::PerformanceCounter>
return rc;
}
void runtime::interpreter::INTBackend::perform_nan_check(
void runtime::interpreter::INTExecutable::perform_nan_check(
const vector<shared_ptr<HostTensor>>& tensors, const Node* op)
{
size_t arg_number = 1;
......
......@@ -143,6 +143,7 @@ namespace ngraph
namespace interpreter
{
class INTBackend;
class INTExecutable;
}
} // namespace runtime
} // namespace ngraph
......@@ -161,19 +162,27 @@ public:
std::shared_ptr<Tensor> create_tensor(const element::Type& type, const Shape& shape) override;
Handle compile(std::shared_ptr<Function> function) override;
std::shared_ptr<Executable> compile(std::shared_ptr<Function> function,
bool enable_performance_data = false) override;
bool call(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& inputs) override;
bool is_supported(const Node& node) const override;
void set_nan_check(std::shared_ptr<Function> func, bool);
private:
std::set<std::string> m_unsupported_op_name_list;
};
void enable_performance_data(std::shared_ptr<Function> func, bool enable) override;
std::vector<PerformanceCounter>
get_performance_data(std::shared_ptr<Function> func) const override;
class ngraph::runtime::interpreter::INTExecutable : public Executable
{
public:
INTExecutable(const std::shared_ptr<Function>& function,
bool enable_performance_collection = false);
bool is_supported(const Node& node) const override;
bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& inputs) override;
void set_nan_check(bool enable);
std::vector<PerformanceCounter> get_performance_data() const override;
private:
int get_alignment() const { return 64; }
......@@ -186,8 +195,7 @@ private:
std::unordered_map<const Node*, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
};
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
} m_function_instance;
std::set<std::string> m_unsupported_op_name_list;
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
......
......@@ -54,13 +54,24 @@ shared_ptr<runtime::Tensor> runtime::nop::NOPBackend::create_tensor(const elemen
return make_shared<runtime::HostTensor>(type, shape, memory_pointer, "external");
}
runtime::Handle runtime::nop::NOPBackend::compile(shared_ptr<Function> function)
shared_ptr<runtime::Executable>
runtime::nop::NOPBackend::compile(shared_ptr<Function> function,
bool enable_performance_collection)
{
return function;
return make_shared<NOPExecutable>(function, enable_performance_collection);
}
bool runtime::nop::NOPBackend::call(shared_ptr<Function> function,
const vector<shared_ptr<runtime::Tensor>>& outputs,
runtime::nop::NOPExecutable::NOPExecutable(shared_ptr<Function> function,
bool enable_performance_collection)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.run_passes(function);
set_parameters_and_results(*function);
}
bool runtime::nop::NOPExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
return true;
......
......@@ -32,6 +32,7 @@ namespace ngraph
namespace nop
{
class NOPBackend;
class NOPExecutable;
}
}
}
......@@ -44,9 +45,14 @@ public:
std::shared_ptr<Tensor> create_tensor(const element::Type& type, const Shape& shape) override;
Handle compile(std::shared_ptr<Function> function) override;
std::shared_ptr<Executable> compile(std::shared_ptr<Function> function,
bool enable_performance_data = false) override;
};
bool call(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& inputs) override;
class ngraph::runtime::nop::NOPExecutable : public Executable
{
public:
NOPExecutable(std::shared_ptr<Function> function, bool enable_performance_collection = false);
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
};
......@@ -136,8 +136,7 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
stopwatch timer;
timer.start();
auto backend = runtime::Backend::create(backend_name);
backend->enable_performance_data(f, timing_detail);
auto compiled_func = backend->compile(f);
auto compiled_func = backend->compile(f, timing_detail);
timer.stop();
cout.imbue(locale(""));
cout << "compile time: " << timer.get_milliseconds() << "ms" << endl;
......@@ -183,7 +182,7 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
{
for (int i = 0; i < warmup_iterations; i++)
{
backend->call(compiled_func, results, args);
compiled_func->call(results, args);
}
}
......@@ -205,7 +204,7 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
}
}
}
backend->call(compiled_func, results, args);
compiled_func->call(results, args);
if (copy_data)
{
for (size_t result_index = 0; result_index < results.size(); result_index++)
......@@ -222,6 +221,6 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
float time = t1.get_milliseconds();
cout << time / iterations << "ms per iteration" << endl;
vector<runtime::PerformanceCounter> perf_data = backend->get_performance_data(f);
vector<runtime::PerformanceCounter> perf_data = compiled_func->get_performance_data();
return perf_data;
}
......@@ -37,9 +37,6 @@ TEST(INTERPRETER, nan_check_input)
shared_ptr<runtime::Backend> backend = runtime::Backend::create("INTERPRETER");
shared_ptr<runtime::interpreter::INTBackend> ibackend =
static_pointer_cast<runtime::interpreter::INTBackend>(backend);
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{2, 4, NAN, 16});
......@@ -47,9 +44,12 @@ TEST(INTERPRETER, nan_check_input)
copy_data(b, vector<float>{1, 2, 1, 8});
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
ibackend->set_nan_check(handle, true);
EXPECT_ANY_THROW(ibackend->call_with_validate(handle, {result}, {a, b}));
shared_ptr<runtime::Executable> handle = backend->compile(f);
shared_ptr<runtime::interpreter::INTExecutable> ihandle =
static_pointer_cast<runtime::interpreter::INTExecutable>(handle);
ihandle->set_nan_check(true);
EXPECT_ANY_THROW(handle->call_with_validate({result}, {a, b}));
}
TEST(INTERPRETER, nan_check_output)
......@@ -61,9 +61,6 @@ TEST(INTERPRETER, nan_check_output)
shared_ptr<runtime::Backend> backend = runtime::Backend::create("INTERPRETER");
shared_ptr<runtime::interpreter::INTBackend> ibackend =
static_pointer_cast<runtime::interpreter::INTBackend>(backend);
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{2, 4, 0, 16});
......@@ -71,7 +68,9 @@ TEST(INTERPRETER, nan_check_output)
copy_data(b, vector<float>{1, 2, 0, 8});
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
ibackend->set_nan_check(handle, true);
EXPECT_ANY_THROW(ibackend->call_with_validate(handle, {result}, {a, b}));
shared_ptr<runtime::Executable> handle = backend->compile(f);
shared_ptr<runtime::interpreter::INTExecutable> ihandle =
static_pointer_cast<runtime::interpreter::INTExecutable>(handle);
ihandle->set_nan_check(true);
EXPECT_ANY_THROW(handle->call_with_validate({result}, {a, b}));
}
......@@ -6728,3 +6728,30 @@ NGRAPH_TEST(${BACKEND_NAME}, quantize_dynamic_offset)
EXPECT_EQ((vector<output_c_type>{1, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 7}),
read_vector<output_c_type>(y));
}
NGRAPH_TEST(${BACKEND_NAME}, get_parameters_and_results)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape);
copy_data(a, test::NDArray<float, 2>({{1, 2}, {3, 4}}).get_vector());
copy_data(b, test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector());
copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());
auto handle = backend->compile(f);
auto parameters = handle->get_parameters();
auto results = handle->get_results();
EXPECT_EQ(parameters.size(), 3);
EXPECT_EQ(results.size(), 1);
}
......@@ -61,8 +61,8 @@ TEST(debugger, add_breakpoint)
copy_data(a, dataA);
copy_data(b, dataB);
auto cf =
std::dynamic_pointer_cast<ngraph::runtime::cpu::CPU_Backend>(backend)->get_call_frame(f);
shared_ptr<runtime::Executable> handle = backend->compile(f);
auto cf = dynamic_pointer_cast<runtime::cpu::CPU_Executable>(handle)->get_call_frame();
ngraph::runtime::cpu::CPU_Debugger dbg(*cf);
......@@ -97,8 +97,8 @@ TEST(debugger, stepping)
copy_data(a, dataA);
copy_data(b, dataB);
auto cf =
std::dynamic_pointer_cast<ngraph::runtime::cpu::CPU_Backend>(backend)->get_call_frame(f);
shared_ptr<runtime::Executable> handle = backend->compile(f);
auto cf = dynamic_pointer_cast<runtime::cpu::CPU_Executable>(handle)->get_call_frame();
ngraph::runtime::cpu::CPU_Debugger dbg(*cf);
......@@ -134,8 +134,8 @@ TEST(debugger, delete_breakpoint)
copy_data(a, dataA);
copy_data(b, dataB);
auto cf =
std::dynamic_pointer_cast<ngraph::runtime::cpu::CPU_Backend>(backend)->get_call_frame(f);
shared_ptr<runtime::Executable> handle = backend->compile(f);
auto cf = dynamic_pointer_cast<runtime::cpu::CPU_Executable>(handle)->get_call_frame();
ngraph::runtime::cpu::CPU_Debugger dbg(*cf);
......@@ -174,8 +174,8 @@ TEST(debugger, while_stepping)
copy_data(a, dataA);
copy_data(b, dataB);
auto cf =
std::dynamic_pointer_cast<ngraph::runtime::cpu::CPU_Backend>(backend)->get_call_frame(f);
shared_ptr<runtime::Executable> handle = backend->compile(f);
auto cf = dynamic_pointer_cast<runtime::cpu::CPU_Executable>(handle)->get_call_frame();
ngraph::runtime::cpu::CPU_Debugger dbg(*cf);
......@@ -212,8 +212,8 @@ TEST(debugger, resume)
copy_data(a, dataA);
copy_data(b, dataB);
auto cf =
std::dynamic_pointer_cast<ngraph::runtime::cpu::CPU_Backend>(backend)->get_call_frame(f);
shared_ptr<runtime::Executable> handle = backend->compile(f);
auto cf = dynamic_pointer_cast<runtime::cpu::CPU_Executable>(handle)->get_call_frame();
ngraph::runtime::cpu::CPU_Debugger dbg(*cf);
......@@ -248,8 +248,8 @@ TEST(tracer, basic)
copy_data(a, dataA);
copy_data(b, dataB);
auto cf =
std::dynamic_pointer_cast<ngraph::runtime::cpu::CPU_Backend>(backend)->get_call_frame(f);
shared_ptr<runtime::Executable> handle = backend->compile(f);
auto cf = dynamic_pointer_cast<runtime::cpu::CPU_Executable>(handle)->get_call_frame();
ngraph::runtime::cpu::CPU_Debugger dbg(*cf);
......@@ -281,8 +281,8 @@ TEST(tracer, count_tracepoint)
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::i32, shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::i32, shape);
auto cf =
std::dynamic_pointer_cast<ngraph::runtime::cpu::CPU_Backend>(backend)->get_call_frame(f);
shared_ptr<runtime::Executable> handle = backend->compile(f);
auto cf = dynamic_pointer_cast<runtime::cpu::CPU_Executable>(handle)->get_call_frame();
ngraph::runtime::cpu::CPU_Debugger dbg(*cf);
......@@ -322,8 +322,8 @@ TEST(tracer, conditional_tracepoint)
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::i32, shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::i32, shape);
auto cf =
std::dynamic_pointer_cast<ngraph::runtime::cpu::CPU_Backend>(backend)->get_call_frame(f);
shared_ptr<runtime::Executable> handle = backend->compile(f);
auto cf = dynamic_pointer_cast<runtime::cpu::CPU_Executable>(handle)->get_call_frame();
ngraph::runtime::cpu::CPU_Debugger dbg(*cf);
......