Unverified commit 777600c6 authored by Robert Kimball, committed by GitHub

New backend/transformer API (#739)

* force backend compile() to make a copy of the graph

fix copy_with_new_args on ops that hold internal function pointers

update unit test for new backend API

add unit test for multiple simultaneous backends

* move get_subdevices virtual method to Manager class

* update GPU to latest

* update call methods

* add remove_compiled_function()
parent ca4a83ea
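Taken together, the changes below replace the Manager/ExternalFunction/CallFrame sequence with calls made directly on a Backend. The following is a minimal usage sketch of the new API as declared in runtime/backend.hpp further down; the Add graph, the INTERPRETER device name, and the tensor sizes are illustrative only and not part of this commit.

#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/backend.hpp"

using namespace ngraph;

int main()
{
    // Build a trivial function: r = A + B over 4-element f32 vectors (illustrative).
    Shape shape{4};
    auto A = std::make_shared<op::Parameter>(element::f32, shape);
    auto B = std::make_shared<op::Parameter>(element::f32, shape);
    auto f = std::make_shared<Function>(NodeVector{std::make_shared<op::Add>(A, B)},
                                        op::ParameterVector{A, B});

    // New API: create a backend by name, let it allocate tensors, compile and call.
    auto backend = runtime::Backend::create("INTERPRETER");
    auto a = backend->create_tensor(element::f32, shape);
    auto b = backend->create_tensor(element::f32, shape);
    auto result = backend->create_tensor(element::f32, shape);
    // ... fill a and b through the TensorView write interface ...

    backend->compile(*f);                  // cached per Function; call() also compiles on demand
    backend->call(*f, {result}, {a, b});   // outputs first, then inputs
    backend->remove_compiled_function(*f); // drop the cached compilation when done
    return 0;
}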
@@ -219,7 +219,7 @@ int main(int argc, const char* argv[])
// X, Y, learning_rate, W0, b0, W1, b1 -> loss, softmax, W0_next, b0_next, W1_next, b1_next
NodeMap train_node_map;
auto train_function = clone_function(
- std::make_shared<Function>(
+ Function(
NodeVector{loss, softmax, W0_next, b0_next, W1_next, b1_next},
op::ParameterVector{X, Y, N, learning_rate, W0, b0, W1, b1}),
train_node_map);
@@ -229,10 +229,10 @@ int main(int argc, const char* argv[])
// Plain inference
// X, W0, b0, W1, b1 -> softmax
NodeMap inference_node_map;
- auto inference_function = clone_function(
- std::make_shared<Function>(NodeVector{softmax},
+ auto inference_function =
+ clone_function(Function(NodeVector{softmax},
op::ParameterVector{X, W0, b0, W1, b1}),
inference_node_map);
auto inference_ext = manager->compile(inference_function);
auto inference_cf = backend->make_call_frame(inference_ext);
......
@@ -118,6 +118,7 @@ set (SRC
pass/core_fusion.cpp
pattern/matcher.cpp
runtime/aligned_buffer.cpp
+ runtime/backend.cpp
runtime/host_tensor_view.cpp
runtime/interpreter/int_backend.cpp
runtime/interpreter/int_call_frame.cpp
......
@@ -226,15 +226,21 @@ std::list<std::shared_ptr<ngraph::Node>>
return cloned_nodes;
}
- std::shared_ptr<ngraph::Function> ngraph::clone_function(std::shared_ptr<ngraph::Function> func,
+ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function& func)
+ {
+ NodeMap nm;
+ return clone_function(func, nm);
+ }
+ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function& func,
NodeMap& node_map)
{
// clone function operations
- clone_nodes(func->get_ops(), node_map);
+ clone_nodes(func.get_ops(), node_map);
// get cloned function results and parameters
ResultVector cloned_results;
- for (shared_ptr<Node> node : func->get_results())
+ for (shared_ptr<Node> node : func.get_results())
{
auto result = std::dynamic_pointer_cast<op::Result>(node_map.get(node));
if (!result)
@@ -244,7 +250,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(std::shared_ptr<ngraph:
cloned_results.push_back(result);
}
std::vector<std::shared_ptr<op::Parameter>> cloned_params;
- for (auto param : func->get_parameters())
+ for (auto param : func.get_parameters())
{
cloned_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_map.get(param)));
}
......
@@ -105,9 +105,12 @@ namespace ngraph
// input function is cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned function ops
- std::shared_ptr<ngraph::Function> clone_function(std::shared_ptr<ngraph::Function> func,
+ std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func,
NodeMap& node_map);
+ // input function is cloned and returned
+ std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func);
// Assert that nodes in the function is colocated and return that placement
Placement get_colocated_function_placement(std::shared_ptr<Function> func);
......
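In short, clone_function now takes the Function by const reference, and a second overload drops the NodeMap for callers that do not need the original-to-clone mapping. A small sketch of the two call styles (the wrapper function and its name are illustrative):

#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"

std::shared_ptr<ngraph::Function> duplicate(const ngraph::Function& f)
{
    // Existing overload, now taking a reference: the NodeMap is filled with the
    // original-to-clone node mapping.
    ngraph::NodeMap node_map;
    auto tracked_copy = ngraph::clone_function(f, node_map);

    // Convenience overload added by this commit: no NodeMap required.
    auto plain_copy = ngraph::clone_function(f);
    return plain_copy;
}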
@@ -16,6 +16,7 @@
#include "ngraph/op/function_call.hpp"
#include "ngraph/function.hpp"
+ #include "ngraph/graph_util.hpp"
using namespace std;
using namespace ngraph;
@@ -51,7 +52,9 @@ op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector&
shared_ptr<Node> op::FunctionCall::copy_with_new_args(const NodeVector& new_args) const
{
- return make_shared<FunctionCall>(m_function, new_args);
+ shared_ptr<FunctionCall> fc = make_shared<FunctionCall>(m_function, new_args);
+ fc->m_function = clone_function(*m_function);
+ return fc;
}
......
@@ -16,6 +16,7 @@
#include "ngraph/op/reduce.hpp"
#include "ngraph/function.hpp"
+ #include "ngraph/graph_util.hpp"
using namespace std;
using namespace ngraph;
@@ -99,6 +100,8 @@ shared_ptr<Node> op::Reduce::copy_with_new_args(const NodeVector& new_args) cons
{
throw ngraph_error("Incorrect number of new arguments");
}
- return make_shared<Reduce>(
- new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes);
+ shared_ptr<Reduce> fc =
+ make_shared<Reduce>(new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes);
+ fc->m_reduction_function = clone_function(*m_reduction_function);
+ return fc;
}
@@ -16,6 +16,7 @@
#include "ngraph/op/reduce_window.hpp"
#include "ngraph/function.hpp"
+ #include "ngraph/graph_util.hpp"
#include "ngraph/util.hpp"
using namespace std;
@@ -136,9 +137,11 @@ shared_ptr<Node> op::ReduceWindow::copy_with_new_args(const NodeVector& new_args
{
throw ngraph_error("Incorrect number of new arguments");
}
- return make_shared<ReduceWindow>(new_args.at(0),
+ auto node = make_shared<ReduceWindow>(new_args.at(0),
new_args.at(1),
m_reduction_function,
m_window_shape,
m_window_movement_strides);
+ node->m_reduction_function = clone_function(*m_reduction_function);
+ return node;
}
@@ -16,6 +16,7 @@
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/function.hpp"
+ #include "ngraph/graph_util.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/util.hpp"
@@ -223,11 +224,14 @@ shared_ptr<Node> op::SelectAndScatter::copy_with_new_args(const NodeVector& new_
{
throw ngraph_error("Incorrect number of new arguments");
}
- return make_shared<SelectAndScatter>(new_args.at(0),
+ auto node = make_shared<SelectAndScatter>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_selection_function,
m_scatter_function,
m_window_shape,
m_window_movement_strides);
+ node->m_selection_function = clone_function(*m_selection_function);
+ node->m_scatter_function = clone_function(*m_scatter_function);
+ return node;
}
@@ -68,7 +68,7 @@ bool ngraph::pass::Inliner::inline_function_call(std::shared_ptr<ngraph::Node> i
nm.add(callee->get_parameters().at(i), callsite->get_input_op(i));
}
- ngraph::clone_function(callee, nm);
+ ngraph::clone_function(*callee, nm);
auto callee_graph = nm.get(callee->get_result());
caller->replace_node(callsite, callee_graph);
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
#include "ngraph/runtime/manager.hpp"
using namespace std;
using namespace ngraph;
std::shared_ptr<runtime::Backend> runtime::Backend::create(const std::string& type)
{
std::shared_ptr<Manager> manager = runtime::Manager::get(type);
return manager->allocate_backend();
}
vector<string> runtime::Backend::get_registered_devices()
{
vector<string> rc;
for (const pair<string, runtime::Manager::Factory>& p : runtime::Manager::get_factory_map())
{
rc.push_back(p.first);
}
return rc;
}
vector<size_t> runtime::Backend::get_subdevices(const string& type)
{
std::shared_ptr<Manager> manager = runtime::Manager::get(type);
return manager->get_subdevices();
}
void runtime::Backend::remove_compiled_function(const Function& func)
{
}
@@ -18,16 +18,12 @@
#include <memory>
+ #include "ngraph/function.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
- namespace element
- {
- class Type;
- }
namespace runtime
{
class ExternalFunction;
@@ -44,26 +40,66 @@ namespace ngraph
/// @brief Make a call frame that can support one concurrent call of an external function.
///
/// If more than one concurrent execution is needed, each execution will require its own call frame.
+ /// DEPRECATED
virtual std::shared_ptr<ngraph::runtime::CallFrame>
make_call_frame(const std::shared_ptr<ExternalFunction>& external_function) = 0;
/// @brief Return a handle for a tensor on the backend device.
+ /// DEPRECATED
virtual std::shared_ptr<ngraph::runtime::TensorView>
make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape) = 0;
/// DEPRECATED
template <typename T>
std::shared_ptr<ngraph::runtime::TensorView>
make_primary_tensor_view(const Shape& shape)
{
return make_primary_tensor_view(element::from<T>(), shape);
}
/// @brief Return a handle for a tensor for given mem on backend device
virtual std::shared_ptr<ngraph::runtime::TensorView>
make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape,
void* memory_pointer) = 0;
/// DEPRECATED
virtual bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) = 0;
/// @brief Create a new Backend object
/// @param type The name of a registered backend, such as "CPU" or "GPU".
/// To select a subdevice use "GPU:N" where `N` is the subdevice number.
/// @returns shared_ptr to a new Backend or nullptr if the named backend
/// does not exist.
static std::shared_ptr<Backend> create(const std::string& type);
/// @brief Query the list of registered devices
/// @returns A vector of all registered devices.
static std::vector<std::string> get_registered_devices();
/// @brief Query the list of available subdevices of a particular device.
/// @param type The name of a registered backend, such as "CPU" or "GPU"
/// @returns A vector of available devices of the specified type.
static std::vector<size_t> get_subdevices(const std::string& type);
virtual std::shared_ptr<ngraph::runtime::TensorView>
create_tensor(const ngraph::element::Type& element_type, const Shape& shape) = 0;
template <typename T>
- std::shared_ptr<ngraph::runtime::TensorView>
- make_primary_tensor_view(const Shape& shape)
+ std::shared_ptr<ngraph::runtime::TensorView> create_tensor(const Shape& shape)
{
- return make_primary_tensor_view(element::from<T>(), shape);
+ return create_tensor(element::from<T>(), shape);
}
virtual bool compile(const ngraph::Function& func) = 0;
virtual bool call(const ngraph::Function& func,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) = 0;
virtual void remove_compiled_function(const ngraph::Function& func);
};
}
}
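For reference, the deprecated entry points kept above coexist with the new ones. A sketch contrasting the two call paths (assumes a built CPU backend and already-populated tensors; not part of the diff):

#include <memory>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/manager.hpp"

using TensorVec = std::vector<std::shared_ptr<ngraph::runtime::TensorView>>;

void run_both_ways(const std::shared_ptr<ngraph::Function>& f,
                   const TensorVec& outputs,
                   const TensorVec& inputs)
{
    // Deprecated path: Manager -> ExternalFunction -> CallFrame.
    auto manager = ngraph::runtime::Manager::get("CPU");
    auto external = manager->compile(f);
    auto backend = manager->allocate_backend();
    auto frame = backend->make_call_frame(external);
    frame->call(outputs, inputs); // outputs first, then inputs

    // New path added by this commit: the Backend compiles and calls directly.
    auto backend2 = ngraph::runtime::Backend::create("CPU");
    backend2->compile(*f);
    backend2->call(*f, outputs, inputs);
}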
@@ -15,15 +15,19 @@
*******************************************************************************/
#include "ngraph/runtime/cpu/cpu_backend.hpp"
- #include "ngraph/log.hpp"
+ #include "ngraph/graph_util.hpp"
+ #include "ngraph/runtime/call_frame.hpp"
+ #include "ngraph/runtime/cpu/cpu_call_frame.hpp"
+ #include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
#include "ngraph/runtime/external_function.hpp"
+ #include "ngraph/util.hpp"
using namespace ngraph;
using namespace std;
std::shared_ptr<ngraph::runtime::CallFrame> runtime::cpu::CPU_Backend::make_call_frame(
- const std::shared_ptr<ExternalFunction>& external_function)
+ const std::shared_ptr<runtime::ExternalFunction>& external_function)
{
return external_function->make_call_frame();
}
@@ -32,8 +36,69 @@ std::shared_ptr<ngraph::runtime::TensorView>
runtime::cpu::CPU_Backend::make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape)
{
- auto rc = make_shared<runtime::cpu::CPUTensorView>(element_type, shape);
- return dynamic_pointer_cast<runtime::TensorView>(rc);
+ return make_shared<runtime::cpu::CPUTensorView>(element_type, shape);
}
std::shared_ptr<ngraph::runtime::TensorView>
runtime::cpu::CPU_Backend::create_tensor(const ngraph::element::Type& element_type,
const Shape& shape)
{
return make_shared<runtime::cpu::CPUTensorView>(element_type, shape);
}
bool runtime::cpu::CPU_Backend::compile(const ngraph::Function& func)
{
if (!contains_key(m_function_map, &func))
{
FunctionInstance instance;
instance.m_function = clone_function(func);
instance.m_external_function = make_shared<CPU_ExternalFunction>(instance.m_function);
auto cf = instance.m_external_function->make_call_frame();
instance.m_call_frame = dynamic_pointer_cast<CPU_CallFrame>(cf);
m_function_map.insert({&func, instance});
}
return true;
}
bool runtime::cpu::CPU_Backend::call(const Function& func,
const vector<shared_ptr<runtime::TensorView>>& outputs,
const vector<shared_ptr<runtime::TensorView>>& inputs)
{
bool rc = true;
auto it = m_function_map.find(&func);
if (it == m_function_map.end())
{
compile(func);
it = m_function_map.find(&func);
}
if (it == m_function_map.end())
{
throw runtime_error("Error constructing backend.");
}
FunctionInstance& instance = it->second;
instance.m_call_frame->call(outputs, inputs);
return rc;
}
bool runtime::cpu::CPU_Backend::call(
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs)
{
if (m_function_map.size() != 1)
{
throw runtime_error("This call method only works if a single function is compiled");
}
FunctionInstance& instance = m_function_map.begin()->second;
instance.m_call_frame->call(outputs, inputs);
return true;
}
void runtime::cpu::CPU_Backend::remove_compiled_function(const Function& func)
{
m_function_map.erase(&func);
}
std::shared_ptr<ngraph::runtime::TensorView> runtime::cpu::CPU_Backend::make_primary_tensor_view(
......
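The CPU backend above keeps one FunctionInstance per compiled Function, keyed by the Function's address: compile() clones the graph before building the external function, call() compiles lazily if needed, and remove_compiled_function() evicts the cached entry. A sketch of the resulting lifecycle (function and tensors are assumed to exist; not part of the diff):

#include <memory>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/runtime/backend.hpp"

void lifecycle_sketch(const ngraph::Function& f,
                      const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs,
                      const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs)
{
    auto backend = ngraph::runtime::Backend::create("CPU");

    backend->compile(f);                  // clones f, builds the external function and call frame, caches by &f
    backend->call(f, outputs, inputs);    // looks up the cached instance, compiling on demand if missing
    backend->call(outputs, inputs);       // shortcut overload: valid only while exactly one function is compiled
    backend->remove_compiled_function(f); // erases the cached instance; a later call(f, ...) recompiles
}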
@@ -16,14 +16,21 @@
#pragma once
+ #include <map>
+ #include <memory>
#include "ngraph/runtime/backend.hpp"
namespace ngraph
{
namespace runtime
{
+ class CallFrame;
namespace cpu
{
+ class CPU_ExternalFunction;
+ class CPU_CallFrame;
class CPU_Backend : public runtime::Backend
{
public:
@@ -39,6 +46,32 @@ namespace ngraph
make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape,
void* memory_pointer) override;
std::shared_ptr<ngraph::runtime::TensorView>
create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) override;
bool compile(const ngraph::Function& fun) override;
bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
bool call(const ngraph::Function& fun,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
void remove_compiled_function(const Function& func) override;
private:
class FunctionInstance
{
public:
std::shared_ptr<CPU_ExternalFunction> m_external_function;
std::shared_ptr<CPU_CallFrame> m_call_frame;
std::shared_ptr<Function> m_function;
};
std::map<const Function*, FunctionInstance> m_function_map;
};
}
}
......
@@ -20,6 +20,7 @@
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_manager.hpp"
+ using namespace std;
using namespace ngraph;
std::shared_ptr<ngraph::runtime::Backend> runtime::cpu::CPU_Manager::allocate_backend()
@@ -27,6 +28,11 @@ std::shared_ptr<ngraph::runtime::Backend> runtime::cpu::CPU_Manager::allocate_ba
return std::make_shared<CPU_Backend>();
}
vector<size_t> runtime::cpu::CPU_Manager::get_subdevices() const
{
throw runtime_error("unimplemented");
}
std::shared_ptr<ngraph::runtime::ExternalFunction>
runtime::cpu::CPU_Manager::compile(const std::shared_ptr<ngraph::Function>& fun)
{
......
@@ -40,6 +40,8 @@ namespace ngraph
public:
virtual std::shared_ptr<Backend> allocate_backend() override;
+ virtual std::vector<size_t> get_subdevices() const override;
virtual std::shared_ptr<ngraph::runtime::ExternalFunction>
compile(const std::shared_ptr<ngraph::Function>& fun) override;
......
@@ -15,8 +15,11 @@
*******************************************************************************/
#include "ngraph/runtime/gpu/gpu_backend.hpp"
+ #include "ngraph/graph_util.hpp"
#include "ngraph/runtime/external_function.hpp"
+ #include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_view.hpp"
+ #include "ngraph/util.hpp"
using namespace ngraph;
using namespace std;
@@ -35,6 +38,65 @@ std::shared_ptr<ngraph::runtime::TensorView>
return dynamic_pointer_cast<runtime::TensorView>(rc);
}
std::shared_ptr<ngraph::runtime::TensorView>
runtime::gpu::GPU_Backend::create_tensor(const ngraph::element::Type& element_type,
const Shape& shape)
{
auto rc = make_shared<runtime::gpu::GPU_TensorView>(element_type, shape);
return dynamic_pointer_cast<runtime::TensorView>(rc);
}
bool runtime::gpu::GPU_Backend::compile(const ngraph::Function& func)
{
if (!contains_key(m_function_map, &func))
{
FunctionInstance instance;
instance.m_function = clone_function(func);
instance.m_external_function = make_shared<GPU_ExternalFunction>(instance.m_function);
auto cf = instance.m_external_function->make_call_frame();
instance.m_call_frame = dynamic_pointer_cast<GPU_CallFrame>(cf);
m_function_map.insert({&func, instance});
}
return true;
}
bool runtime::gpu::GPU_Backend::call(
const ngraph::Function& func,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs)
{
bool rc = true;
auto it = m_function_map.find(&func);
if (it == m_function_map.end())
{
compile(func);
it = m_function_map.find(&func);
}
if (it == m_function_map.end())
{
throw runtime_error("Error constructing backend.");
}
FunctionInstance& instance = it->second;
instance.m_call_frame->call(outputs, inputs);
return rc;
}
bool runtime::gpu::GPU_Backend::call(
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs)
{
if (m_function_map.size() != 1)
{
throw runtime_error("This call method only works if a single function is compiled");
}
FunctionInstance& instance = m_function_map.begin()->second;
instance.m_call_frame->call(outputs, inputs);
return true;
}
std::shared_ptr<ngraph::runtime::TensorView> runtime::gpu::GPU_Backend::make_primary_tensor_view(
const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer)
{
......
@@ -16,6 +16,9 @@
#pragma once
+ #include <map>
+ #include <memory>
#include "ngraph/runtime/backend.hpp"
namespace ngraph
@@ -26,6 +29,9 @@ namespace ngraph
{
static size_t alignment = 64;
+ class GPU_ExternalFunction;
+ class GPU_CallFrame;
class GPU_Backend : public Backend
{
public:
@@ -41,6 +47,30 @@ namespace ngraph
make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape,
void* memory_pointer) override;
std::shared_ptr<ngraph::runtime::TensorView>
create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) override;
bool compile(const ngraph::Function& fun) override;
bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
bool call(const ngraph::Function& fun,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
private:
class FunctionInstance
{
public:
std::shared_ptr<GPU_ExternalFunction> m_external_function;
std::shared_ptr<GPU_CallFrame> m_call_frame;
std::shared_ptr<Function> m_function;
};
std::map<const Function*, FunctionInstance> m_function_map;
};
}
}
......
@@ -18,20 +18,26 @@
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
- using namespace ngraph::runtime::gpu;
+ using namespace ngraph;
- std::shared_ptr<ngraph::runtime::Backend> GPU_Manager::allocate_backend()
+ std::shared_ptr<ngraph::runtime::Backend> runtime::gpu::GPU_Manager::allocate_backend()
{
return std::make_shared<GPU_Backend>();
}
+ std::vector<size_t> runtime::gpu::GPU_Manager::get_subdevices() const
+ {
+ throw std::runtime_error("Unimplemented method");
+ }
std::shared_ptr<ngraph::runtime::ExternalFunction>
- GPU_Manager::compile(const std::shared_ptr<ngraph::Function>& fun)
+ runtime::gpu::GPU_Manager::compile(const std::shared_ptr<ngraph::Function>& fun)
{
return std::make_shared<GPU_ExternalFunction>(fun);
}
- ngraph::runtime::Manager::Factory GPU_Manager::factory = ngraph::runtime::Manager::register_factory(
- "GPU", [](const std::string& name) -> std::shared_ptr<ngraph::runtime::Manager> {
- return std::make_shared<GPU_Manager>();
- });
+ ngraph::runtime::Manager::Factory runtime::gpu::GPU_Manager::factory =
+ ngraph::runtime::Manager::register_factory(
+ "GPU", [](const std::string& name) -> std::shared_ptr<ngraph::runtime::Manager> {
+ return std::make_shared<GPU_Manager>();
+ });
@@ -16,6 +16,8 @@
#pragma once
+ #include <memory>
#include "ngraph/runtime/manager.hpp"
namespace ngraph
@@ -29,6 +31,8 @@ namespace ngraph
public:
virtual std::shared_ptr<Backend> allocate_backend() override;
+ virtual std::vector<size_t> get_subdevices() const override;
virtual std::shared_ptr<ngraph::runtime::ExternalFunction>
compile(const std::shared_ptr<ngraph::Function>& fun) override;
......
@@ -15,15 +15,17 @@
*******************************************************************************/
#include "ngraph/runtime/interpreter/int_backend.hpp"
- #include "ngraph/log.hpp"
+ #include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
+ #include "ngraph/runtime/interpreter/int_call_frame.hpp"
+ #include "ngraph/runtime/interpreter/int_external_function.hpp"
using namespace ngraph;
using namespace std;
shared_ptr<runtime::CallFrame> runtime::interpreter::INT_Backend::make_call_frame(
- const shared_ptr<ExternalFunction>& external_function)
+ const shared_ptr<runtime::ExternalFunction>& external_function)
{
return external_function->make_call_frame();
}
@@ -42,3 +44,61 @@ shared_ptr<runtime::TensorView> runtime::interpreter::INT_Backend::make_primary_
auto rc = make_shared<runtime::HostTensorView>(element_type, shape, memory_pointer, "external");
return static_pointer_cast<runtime::TensorView>(rc);
}
shared_ptr<ngraph::runtime::TensorView>
runtime::interpreter::INT_Backend::create_tensor(const ngraph::element::Type& element_type,
const Shape& shape)
{
auto rc = make_shared<runtime::HostTensorView>(element_type, shape, "external");
return static_pointer_cast<runtime::TensorView>(rc);
}
bool runtime::interpreter::INT_Backend::compile(const ngraph::Function& func)
{
if (!contains_key(m_function_map, &func))
{
FunctionInstance instance;
instance.m_function = clone_function(func);
instance.m_external_function =
make_shared<interpreter::ExternalFunction>(instance.m_function);
auto cf = instance.m_external_function->make_call_frame();
instance.m_call_frame = dynamic_pointer_cast<interpreter::INT_CallFrame>(cf);
m_function_map.insert({&func, instance});
}
return true;
}
bool runtime::interpreter::INT_Backend::call(const Function& fun,
const vector<shared_ptr<runtime::TensorView>>& outputs,
const vector<shared_ptr<runtime::TensorView>>& inputs)
{
bool rc = true;
auto it = m_function_map.find(&fun);
if (it == m_function_map.end())
{
compile(fun);
it = m_function_map.find(&fun);
}
if (it == m_function_map.end())
{
throw runtime_error("Error constructing backend.");
}
FunctionInstance& instance = it->second;
instance.m_call_frame->call(outputs, inputs);
return rc;
}
bool runtime::interpreter::INT_Backend::call(const vector<shared_ptr<runtime::TensorView>>& outputs,
const vector<shared_ptr<runtime::TensorView>>& inputs)
{
if (m_function_map.size() != 1)
{
throw runtime_error("This call method only works if a single function is compiled");
}
FunctionInstance& instance = m_function_map.begin()->second;
instance.m_call_frame->call(outputs, inputs);
return true;
}
@@ -16,14 +16,22 @@
#pragma once
+ #include <map>
+ #include <memory>
#include "ngraph/runtime/backend.hpp"
namespace ngraph
{
namespace runtime
{
+ class CallFrame;
namespace interpreter
{
+ class ExternalFunction;
+ class INT_CallFrame;
class INT_Backend : public runtime::Backend
{
public:
@@ -39,6 +47,30 @@ namespace ngraph
make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape,
void* memory_pointer) override;
std::shared_ptr<ngraph::runtime::TensorView>
create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) override;
bool compile(const ngraph::Function& fun) override;
bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
bool call(const ngraph::Function& fun,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
private:
class FunctionInstance
{
public:
std::shared_ptr<interpreter::ExternalFunction> m_external_function;
std::shared_ptr<interpreter::INT_CallFrame> m_call_frame;
std::shared_ptr<Function> m_function;
};
std::map<const Function*, FunctionInstance> m_function_map;
};
}
}
......
@@ -25,10 +25,8 @@
using namespace std;
using namespace ngraph;
- runtime::interpreter::INT_CallFrame::INT_CallFrame(shared_ptr<ExternalFunction> external_function,
- shared_ptr<Function> func)
- : m_external_function(external_function)
- , m_function(func)
+ runtime::interpreter::INT_CallFrame::INT_CallFrame(shared_ptr<Function> func)
+ : m_function(func)
, m_emit_timing(std::getenv("NGRAPH_INTERPRETER_EMIT_TIMING") != nullptr)
, m_nan_check(std::getenv("NGRAPH_INTERPRETER_NAN_CHECK") != nullptr)
{
......
@@ -119,7 +119,6 @@ namespace ngraph
namespace interpreter
{
- class ExternalFunction;
class INT_CallFrame;
}
}
@@ -129,8 +128,7 @@ namespace ngraph
class ngraph::runtime::interpreter::INT_CallFrame : public runtime::CallFrame
{
public:
- INT_CallFrame(std::shared_ptr<ExternalFunction> external_function,
- std::shared_ptr<Function> func);
+ INT_CallFrame(std::shared_ptr<Function> func);
/// @brief Invoke the function with values matching the signature of the function.
///
@@ -155,7 +153,6 @@ private:
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensorView>>&,
const Node* op = nullptr);
- std::shared_ptr<ExternalFunction> m_external_function;
std::shared_ptr<Function> m_function;
bool m_emit_timing;
bool m_nan_check;
......
@@ -106,9 +106,6 @@ void runtime::interpreter::ExternalFunction::compile()
return;
}
- string function_name = m_interpreter_function->get_name();
- string dump_filename = file_util::path_join(s_output_dir, function_name + "_ops.txt");
pass::Manager pass_manager;
// For now, just make everyone row-major.
pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>();
@@ -129,6 +126,5 @@ shared_ptr<runtime::CallFrame> runtime::interpreter::ExternalFunction::make_call
compile();
}
- return make_shared<runtime::interpreter::INT_CallFrame>(shared_from_this(),
- m_interpreter_function);
+ return make_shared<runtime::interpreter::INT_CallFrame>(m_function);
}
@@ -28,12 +28,11 @@ namespace ngraph
{
namespace interpreter
{
- class ExternalFunction : public ngraph::runtime::ExternalFunction,
- public std::enable_shared_from_this<ExternalFunction>
+ class ExternalFunction : public ngraph::runtime::ExternalFunction
{
public:
ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
- bool release_function = true);
+ bool release_function = false);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
protected:
......
@@ -28,6 +28,12 @@ shared_ptr<runtime::Backend> runtime::interpreter::INT_Manager::allocate_backend
return make_shared<INT_Backend>();
}
std::vector<size_t> runtime::interpreter::INT_Manager::get_subdevices() const
{
vector<size_t> rc;
return rc;
}
shared_ptr<runtime::ExternalFunction>
runtime::interpreter::INT_Manager::compile(const shared_ptr<Function>& fun)
{
......
@@ -17,6 +17,7 @@
#pragma once
#include <memory>
+ #include <vector>
#include "ngraph/runtime/manager.hpp"
@@ -36,6 +37,8 @@ namespace ngraph
public:
virtual std::shared_ptr<Backend> allocate_backend() override;
+ virtual std::vector<size_t> get_subdevices() const override;
virtual std::shared_ptr<ngraph::runtime::ExternalFunction>
compile(const std::shared_ptr<ngraph::Function>& fun) override;
......
@@ -19,6 +19,7 @@
#include <iostream>
#include <sstream>
#include <string>
+ #include <vector>
#include "ngraph/except.hpp"
#include "ngraph/log.hpp"
......
@@ -20,6 +20,7 @@
#include <map>
#include <memory>
#include <string>
+ #include <vector>
namespace ngraph
{
@@ -36,13 +37,21 @@ namespace ngraph
/// a backed for execution and allocation.
class Manager
{
+ friend class runtime::Backend;
public:
virtual ~Manager() {}
+ /// DEPRECATED
/// @brief Allocate a backend for this transformer.
///
/// Specific transformers may provide addtional methods for allocating customized backends.
virtual std::shared_ptr<Backend> allocate_backend() = 0;
/// @brief Query the list of available subdevices of this device.
/// @returns A vector of available devices of the specified type.
virtual std::vector<size_t> get_subdevices() const = 0;
/// DEPRECATED
/// @brief Convert a function to a form that can be run on a backend.
virtual std::shared_ptr<ExternalFunction>
compile(const std::shared_ptr<ngraph::Function>& fun) = 0;
......
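The static helpers on Backend forward to these Manager methods. A short sketch of enumerating registered devices and probing their subdevices (note that in this commit only the INTERPRETER manager implements get_subdevices, returning an empty list, while the CPU and GPU managers still throw; the function below is illustrative, not part of the diff):

#include <iostream>
#include <stdexcept>
#include <string>
#include "ngraph/runtime/backend.hpp"

void list_devices()
{
    for (const std::string& device : ngraph::runtime::Backend::get_registered_devices())
    {
        std::cout << device << "\n";
        try
        {
            // May throw for backends whose manager has not implemented it yet.
            for (size_t sub : ngraph::runtime::Backend::get_subdevices(device))
            {
                std::cout << "  " << device << ":" << sub << "\n";
            }
        }
        catch (const std::runtime_error&)
        {
            // CPU and GPU managers throw "unimplemented" at this point in the series.
        }
    }
}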
@@ -25,6 +25,7 @@ include_directories(
)
set (SRC
+ backend_api.cpp
backend_debug_api.cpp
builder.cpp
builder_autobroadcast.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
TEST(backend_api, registered_devices)
{
vector<string> devices = runtime::Backend::get_registered_devices();
EXPECT_GE(devices.size(), 1);
EXPECT_TRUE(contains(devices, "INTERPRETER"));
}
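The commit message also mentions a unit test for multiple simultaneous backends; that test is not visible in this excerpt, but a sketch of the idea using only the API added here (graph, shapes, and expectations are illustrative), written in the same style as the test above, could look like:

TEST(backend_api, two_backends_sketch)
{
    // Illustrative sketch, not the test added by this commit.
    Shape shape{2};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::Parameter>(element::f32, shape);
    auto f = make_shared<Function>(NodeVector{make_shared<op::Add>(A, B)},
                                   op::ParameterVector{A, B});

    // Two independent backend instances, each with its own compilation cache.
    auto backend1 = runtime::Backend::create("INTERPRETER");
    auto backend2 = runtime::Backend::create("INTERPRETER");

    auto a1 = backend1->create_tensor(element::f32, shape);
    auto b1 = backend1->create_tensor(element::f32, shape);
    auto r1 = backend1->create_tensor(element::f32, shape);
    auto a2 = backend2->create_tensor(element::f32, shape);
    auto b2 = backend2->create_tensor(element::f32, shape);
    auto r2 = backend2->create_tensor(element::f32, shape);

    EXPECT_TRUE(backend1->call(*f, {r1}, {a1, b1}));
    EXPECT_TRUE(backend2->call(*f, {r2}, {a2, b2}));
}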
This diff is collapsed.
@@ -211,8 +211,9 @@ TEST(copy, FunctionCall)
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
- ASSERT_TRUE(new_args == new_node->get_input_ops());
- ASSERT_TRUE(node_cast->get_functions()[0] == f);
+ EXPECT_TRUE(new_args == new_node->get_input_ops());
+ ASSERT_EQ(node_cast->get_functions().size(), 1);
+ EXPECT_NE(f, node_cast->get_functions()[0]);
}
TEST(copy, greater_eq)
@@ -303,9 +304,10 @@ TEST(copy, reduce)
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
- ASSERT_TRUE(new_args == new_node->get_input_ops());
- ASSERT_TRUE(f == node_cast->get_functions()[0]);
- ASSERT_TRUE(axes == node_cast->get_reduction_axes());
+ EXPECT_TRUE(new_args == new_node->get_input_ops());
+ ASSERT_EQ(node_cast->get_functions().size(), 1);
+ EXPECT_NE(f, node_cast->get_functions()[0]);
+ EXPECT_TRUE(axes == node_cast->get_reduction_axes());
}
TEST(copy, remainder)
......
@@ -278,7 +278,7 @@ TEST_F(CloneTest, clone_nodes_partial)
TEST_F(CloneTest, clone_function_full)
{
- auto cloned_func = clone_function(func, node_map);
+ auto cloned_func = clone_function(*func, node_map);
ASSERT_TRUE(CompareNodeVector(func->get_ops(), cloned_func->get_ops(), node_map));
}
@@ -294,8 +294,7 @@ TEST(graph_util, clone_multiple_results)
auto f =
make_shared<Function>(NodeVector{A_add_B, A_add_B_mul_C}, op::ParameterVector{A, B, C});
- NodeMap node_map;
- auto copy = clone_function(f, node_map);
+ auto copy = clone_function(*f);
}
TEST(util, round_up)
......
@@ -191,15 +191,13 @@ namespace ngraph
}
// compile and run modified (y, cached) = f(x)
- NodeMap nm1;
- auto clone_fwd = clone_function(fprop_cache.fprop, nm1);
+ auto clone_fwd = clone_function(*fprop_cache.fprop);
auto cache_fwd = manager->compile(clone_fwd);
auto cache_fwd_cf = backend->make_call_frame(cache_fwd);
cache_fwd_cf->tensor_call(mod_f_output_args, f_input_args);
// call modfied f'(c, cached) to get df/dX*
- NodeMap nm2;
- auto clone_bwd = clone_function(fprop_cache.bprop, nm2);
+ auto clone_bwd = clone_function(*fprop_cache.bprop);
auto cache_dfdx =
get_autodiff<T>(manager, backend, clone_bwd, mod_df_input_args, indep_params);
......