Unverified Commit 7ad4c0ab authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Add save/load API to runtime (#2955)

* API defined

* add unit test for save/load with INTERPRETER

* Update per review comments

* fix compiler error
parent 4ec94acc
......@@ -263,8 +263,9 @@ const vector<cpio::FileInfo>& cpio::Reader::get_file_info()
return m_file_info;
}
// Copies the payload of the archive member named `file_name` into `data`.
// \param file_name name of the cpio member to locate
// \param data destination buffer; must hold at least `size_in_bytes` bytes
// \param size_in_bytes number of bytes to copy from the member's payload
// \returns true if a member with that name was found, false otherwise
// NOTE(review): the `void ...` line directly below is a diff artifact — the
// pre-change signature that the `bool ...` version replaced.
void cpio::Reader::read(const string& file_name, void* data, size_t size_in_bytes)
bool cpio::Reader::read(const string& file_name, void* data, size_t size_in_bytes)
{
// Found-flag; remains false when no member matches file_name.
bool rc = false;
for (const FileInfo& info : get_file_info())
{
if (info.get_name() == file_name)
{
......@@ -275,9 +276,18 @@ void cpio::Reader::read(const string& file_name, void* data, size_t size_in_byte
// (lines hidden by the diff hunk above — presumably a check on the
// requested byte count; confirm against the full source)
}
// Seek to the member's payload and copy it out verbatim.
m_stream->seekg(info.get_offset(), ios_base::beg);
m_stream->read(reinterpret_cast<char*>(data), size_in_bytes);
rc = true;
break;
}
}
return rc;
}
// Reads an entire archive member described by `info` into a freshly
// allocated buffer.
// \param info metadata (name/size/offset) of the member to read
// \returns the member's raw bytes; length equals info.get_size()
// \throws runtime_error if no member with that name exists in the archive
vector<char> cpio::Reader::read(const FileInfo& info)
{
    vector<char> buffer(info.get_size());
    // The name-based overload reports lookup failure via its bool return;
    // throwing here avoids silently handing back an uninitialized buffer.
    if (!read(info.get_name(), buffer.data(), info.get_size()))
    {
        throw runtime_error("cpio entry '" + info.get_name() + "' not found");
    }
    return buffer;
}
bool cpio::is_cpio(const string& path)
......
......@@ -107,7 +107,8 @@ public:
void open(const std::string& filename);
void close();
const std::vector<FileInfo>& get_file_info();
void read(const std::string& file_name, void* data, size_t size_in_bytes);
bool read(const std::string& file_name, void* data, size_t size_in_bytes);
std::vector<char> read(const FileInfo& info);
private:
std::istream* m_stream;
......
......@@ -85,3 +85,8 @@ bool runtime::Backend::is_supported_property(const Property prop) const
void runtime::Backend::remove_compiled_function(std::shared_ptr<Executable> exec)
{
}
/// \brief Base implementation of Backend::load.
/// Backends that support deserializing a previously saved Executable
/// override this method; the base class always fails.
/// \param input_stream opened stream containing a saved Executable
/// \throws runtime_error always — loading is unimplemented at this level
std::shared_ptr<runtime::Executable> runtime::Backend::load(istream& input_stream)
{
    // Fixed typo in the error text: "opertion" -> "operation".
    throw runtime_error("load operation unimplemented.");
}
......@@ -109,6 +109,11 @@ public:
ngraph::pass::PassConfig& pass_config,
bool enable_performance_data = false);
/// \brief Loads a previously saved Executable object from a stream.
/// \param input_stream the opened input stream containing the saved Executable
/// \returns A compiled function or throws an exception on error
virtual std::shared_ptr<Executable> load(std::istream& input_stream);
/// \brief Test if a backend is capable of supporting an op
/// \param node is the op to test.
/// \returns true if the op is supported, false otherwise.
......
......@@ -118,3 +118,8 @@ vector<runtime::PerformanceCounter> runtime::Executable::get_performance_data()
{
return vector<PerformanceCounter>();
}
/// \brief Base implementation of Executable::save.
/// Backends whose executables can be serialized override this method;
/// the base class always fails.
/// \param output_stream stream that would receive the saved executable
/// \throws runtime_error always — saving is unimplemented at this level
void runtime::Executable::save(std::ostream& output_stream)
{
    // Fixed typo in the error text: "opertion" -> "operation".
    throw runtime_error("save operation unimplemented.");
}
......@@ -69,6 +69,10 @@ public:
/// \returns an ngraph::ResultVector of all input parameters
const ngraph::ResultVector& get_results() const;
/// \brief Save this compiled Executable to an output stream.
/// Saved stream may be read with Backend::load
virtual void save(std::ostream& output_stream);
protected:
/// \brief Called at the end of compile to the values to be returned by get_parameters
/// and get_results
......
......@@ -15,10 +15,12 @@
//*****************************************************************************
#include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/cpio.hpp"
#include "ngraph/except.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/interpreter/int_executable.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -72,3 +74,34 @@ bool runtime::interpreter::INTBackend::is_supported(const Node& node) const
{
return m_unsupported_op_name_list.find(node.description()) == m_unsupported_op_name_list.end();
}
// Reconstructs an Executable from a stream previously written by
// INTExecutable::save. The stream is a cpio archive holding a "save_info"
// version record plus a serialized "model" entry.
// \param in opened input stream positioned at the archive
// \returns the restored Executable, or nullptr when the version record is
//          missing/unrecognized or the archive has no "model" entry
std::shared_ptr<runtime::Executable> runtime::interpreter::INTBackend::load(istream& in)
{
    cpio::Reader reader(in);
    auto entries = reader.get_file_info();

    // Locate an archive entry by name; nullptr when absent.
    auto find_entry = [&entries](const string& name) -> const cpio::FileInfo* {
        for (const cpio::FileInfo& entry : entries)
        {
            if (entry.get_name() == name)
            {
                return &entry;
            }
        }
        return nullptr;
    };

    shared_ptr<Executable> result;
    string version;
    if (const cpio::FileInfo* version_entry = find_entry("save_info"))
    {
        vector<char> raw = reader.read(*version_entry);
        version = string(raw.data(), raw.size());
    }
    // Only attempt deserialization for archives written by the matching
    // save-format version.
    if (version == "INTERPRETER Save File 1.0")
    {
        if (const cpio::FileInfo* model_entry = find_entry("model"))
        {
            vector<char> raw = reader.read(*model_entry);
            string model_string(raw.data(), raw.size());
            result = shared_ptr<INTExecutable>(new INTExecutable(model_string));
        }
    }
    return result;
}
......@@ -54,6 +54,7 @@ public:
std::shared_ptr<Executable> compile(std::shared_ptr<Function> function,
bool enable_performance_data = false) override;
std::shared_ptr<Executable> load(std::istream& input_stream) override;
bool is_supported(const Node& node) const override;
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/runtime/interpreter/int_executable.hpp"
#include "ngraph/cpio.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/convert.hpp"
......@@ -28,6 +29,7 @@
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -37,21 +39,34 @@ using descriptor::layout::DenseTensorLayout;
// Compiles `function` into an interpreter executable.
// The diff-paste left both pre- and post-change lines interleaved here
// (duplicate init-lists and duplicated run_passes/loop/set_parameters
// statements); this is the post-change version, which runs all passes on
// the private clone rather than on the caller's graph.
// \param function the function to compile; it is cloned, not mutated
// \param enable_performance_collection when true, per-op timing is recorded
runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& function,
                                                   bool enable_performance_collection)
    : m_is_compiled{true}
    , m_performance_counters_enabled{enable_performance_collection}
{
    // Work on a private clone so the passes below cannot mutate the
    // caller's function graph.
    m_function = clone_function(*function);
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::LikeReplacement>();
    pass_manager.register_pass<pass::FusedOpDecomposition>();
    pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
    pass_manager.register_pass<pass::Liveness>();
    pass_manager.run_passes(m_function);

    for (const shared_ptr<Node>& node : m_function->get_ordered_ops())
    {
        m_wrapped_nodes.emplace_back(node);
    }
    set_parameters_and_results(*m_function);
}
// Builds an executable directly from a serialized model string, as produced
// by INTExecutable::save. Performance collection is always disabled for a
// deserialized executable.
// \param model_string serialized representation of the function graph
runtime::interpreter::INTExecutable::INTExecutable(const std::string& model_string)
    : m_is_compiled{true}
    , m_performance_counters_enabled{false}
{
    m_function = deserialize(model_string);
    // Wrap every op in execution order for the interpreter's dispatch loop.
    for (const auto& op : m_function->get_ordered_ops())
    {
        m_wrapped_nodes.emplace_back(op);
    }
    set_parameters_and_results(*m_function);
}
bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
......@@ -278,3 +293,12 @@ void runtime::interpreter::INTExecutable::perform_nan_check(
arg_number++;
}
}
// Serializes this executable to `out` as a cpio archive with two entries:
// "save_info" (format-version marker checked by INTBackend::load) and
// "model" (the serialized function graph).
// \param out opened output stream that receives the archive
void runtime::interpreter::INTExecutable::save(ostream& out)
{
    cpio::Writer writer(out);
    const string version_marker = "INTERPRETER Save File 1.0";
    writer.write("save_info", version_marker.data(), version_marker.size());
    const string serialized_model = serialize(m_function, 0);
    writer.write("model", serialized_model.data(), serialized_model.size());
}
......@@ -160,6 +160,8 @@ namespace ngraph
class ngraph::runtime::interpreter::INTExecutable : public Executable
{
friend class INTBackend;
public:
INTExecutable(const std::shared_ptr<Function>& function,
bool enable_performance_collection = false);
......@@ -167,15 +169,20 @@ public:
bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& intputs) override;
virtual void save(std::ostream& output_stream) override;
void set_nan_check(bool enable);
std::vector<PerformanceCounter> get_performance_data() const override;
private:
INTExecutable(const std::string& model_string);
int get_alignment() const { return 64; }
bool m_is_compiled = false;
bool m_nan_check_enabled = false;
bool m_performance_counters_enabled = false;
std::shared_ptr<Function> m_function;
std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
......
......@@ -18,6 +18,8 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/util.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
......@@ -34,3 +36,34 @@ TEST(backend_api, invalid_name)
{
ASSERT_ANY_THROW(ngraph::runtime::Backend::create("COMPLETELY-BOGUS-NAME"));
}
// Round-trip test: compile A + B on INTERPRETER, save the executable to a
// file, reload it through Backend::load, and verify the restored executable
// still computes the sum.
TEST(backend_api, save_load)
{
    Shape shape{2, 2};
    auto param0 = make_shared<op::Parameter>(element::f32, shape);
    auto param1 = make_shared<op::Parameter>(element::f32, shape);
    auto func = make_shared<Function>(make_shared<op::Add>(param0, param1),
                                      ParameterVector{param0, param1});

    auto backend = runtime::Backend::create("INTERPRETER");

    // Tensors for the two addends and the result.
    shared_ptr<runtime::Tensor> t0 = backend->create_tensor(element::f32, shape);
    shared_ptr<runtime::Tensor> t1 = backend->create_tensor(element::f32, shape);
    shared_ptr<runtime::Tensor> sum = backend->create_tensor(element::f32, shape);
    copy_data<float>(t0, {1.f, 2.f, 3.f, 4.f});
    copy_data<float>(t1, {5.f, 6.f, 7.f, 8.f});

    // Compile and persist; the scope closes the stream so the file is flushed.
    {
        ofstream file("test.interpreter_save");
        auto handle = backend->compile(func);
        handle->save(file);
    }
    // Reload from disk and verify the restored executable behaves identically.
    {
        ifstream file("test.interpreter_save");
        auto handle = backend->load(file);
        ASSERT_NE(handle, nullptr);
        handle->call_with_validate({sum}, {t0, t1});
        EXPECT_TRUE(test::all_close_f(read_vector<float>(sum), {6.f, 8.f, 10.f, 12.f}));
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment