Unverified Commit a5d3c78d authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

[ONNX] Support for external weights to enable Caffe2 models (#1941)

* onnx: enable external weights to enable Caffe2 support
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: update ONNX importer interface documentation
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: after review updates
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent d37cafd9
......@@ -41,16 +41,21 @@ static std::shared_ptr<ngraph::Function> import_onnx_function(const std::string&
return ngraph::onnx_import::import_onnx_function(iss);
}
static std::vector<std::shared_ptr<ngraph::Function>>
load_onnx_model_file(const std::string& filename)
{
return ngraph::onnx_import::load_onnx_model(filename);
}
static std::shared_ptr<ngraph::Function> import_onnx_function_file(const std::string& filename)
{
return ngraph::onnx_import::import_onnx_function(filename);
}
void regmodule_pyngraph_onnx_import(py::module mod)
{
mod.def("load_onnx_model", &load_onnx_model);
mod.def("import_onnx_function", &import_onnx_function);
mod.def("load_onnx_model_file",
static_cast<std::vector<std::shared_ptr<ngraph::Function>> (*)(const std::string&)>(
&ngraph::onnx_import::load_onnx_model),
py::arg());
mod.def("import_onnx_function_file",
static_cast<std::shared_ptr<ngraph::Function> (*)(const std::string&)>(
&ngraph::onnx_import::import_onnx_function),
py::arg());
mod.def("load_onnx_model_file", &load_onnx_model_file);
mod.def("import_onnx_function_file", &import_onnx_function_file);
}
......@@ -21,7 +21,8 @@ add_library(onnx_import_interface OBJECT
onnx.hpp
core/operator_set.hpp
core/node.cpp
core/node.hpp)
core/node.hpp
core/weight.hpp)
add_library(onnx_import STATIC
core/attribute.cpp
......@@ -149,6 +150,8 @@ add_library(onnx_import STATIC
utils/reshape.hpp
utils/variadic.hpp)
set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE INTERNAL "")
add_dependencies(onnx_import onnx_import_interface)
if (NOT NGRAPH_USE_SYSTEM_PROTOBUF)
......@@ -157,14 +160,14 @@ if (NOT NGRAPH_USE_SYSTEM_PROTOBUF)
endif()
set_property(TARGET onnx_import PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(onnx_import PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${NGRAPH_INCLUDE_PATH}
target_include_directories(onnx_import PRIVATE ${ONNX_IMPORT_INCLUDE_DIR} ${NGRAPH_INCLUDE_PATH}
SYSTEM PRIVATE ${ONNX_INCLUDE_DIR} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_link_libraries(onnx_import PRIVATE ${Protobuf_LIBRARIES} ${ONNX_PROTO_LIBRARY})
target_compile_definitions(onnx_import PRIVATE ONNX_OPSET_VERSION=${ONNX_OPSET_VERSION})
set_property(TARGET onnx_import_interface PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(onnx_import_interface PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${NGRAPH_INCLUDE_PATH}
target_include_directories(onnx_import_interface PRIVATE ${ONNX_IMPORT_INCLUDE_DIR} ${NGRAPH_INCLUDE_PATH}
SYSTEM PRIVATE ${ONNX_INCLUDE_DIR} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_compile_definitions(onnx_import_interface PRIVATE ONNX_OPSET_VERSION=${ONNX_OPSET_VERSION})
......@@ -172,8 +175,8 @@ target_compile_definitions(onnx_import_interface PRIVATE ONNX_OPSET_VERSION=${ON
if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "^(Apple)?Clang$")
target_compile_options(onnx_import PRIVATE -Wno-undef -Wno-reserved-id-macro -Wno-switch-enum
-Wno-extended-offsetof -Wno-shorten-64-to-32 -Wno-unused-macros -Wno-missing-variable-declarations
-Wno-unused-private-field -Wno-shadow -Wno-deprecated)
-Wno-unused-private-field -Wno-shadow -Wno-deprecated PUBLIC -Wno-undefined-func-template)
target_compile_options(onnx_import_interface PRIVATE -Wno-undef -Wno-reserved-id-macro -Wno-switch-enum
-Wno-extended-offsetof -Wno-shorten-64-to-32 -Wno-unused-macros -Wno-missing-variable-declarations
-Wno-unused-private-field -Wno-shadow -Wno-deprecated)
-Wno-unused-private-field -Wno-shadow -Wno-deprecated PUBLIC -Wno-undefined-func-template)
endif()
......@@ -42,7 +42,9 @@ namespace ngraph
}
}
Graph::Graph(const onnx::GraphProto& graph_proto, const Model& model)
Graph::Graph(const onnx::GraphProto& graph_proto,
const Model& model,
const Weights& weights)
: m_graph_proto{&graph_proto}
, m_model{&model}
{
......@@ -59,7 +61,7 @@ namespace ngraph
{
m_inputs.emplace_back(input);
m_ng_node_cache[input.name()] =
m_inputs.back().get_ng_node(m_parameters, m_initializers);
m_inputs.back().get_ng_node(m_parameters, m_initializers, weights);
}
for (const auto& output : m_graph_proto->output())
......
......@@ -26,6 +26,7 @@
#include "model.hpp"
#include "operator_set.hpp"
#include "value_info.hpp"
#include "weight.hpp"
namespace ngraph
{
......@@ -34,7 +35,7 @@ namespace ngraph
class Graph
{
public:
Graph(const onnx::GraphProto& proto, const Model& model);
Graph(const onnx::GraphProto& proto, const Model& model, const Weights& weights = {});
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
......
......@@ -25,6 +25,7 @@
#include "node.hpp"
#include "tensor.hpp"
#include "weight.hpp"
namespace ngraph
{
......@@ -104,7 +105,8 @@ namespace ngraph
std::shared_ptr<ngraph::Node>
get_ng_node(op::ParameterVector& parameters,
const std::map<std::string, Tensor>& initializers) const
const std::map<std::string, Tensor>& initializers,
const Weights& weights = {}) const
{
const auto it = initializers.find(get_name());
if (it != std::end(initializers))
......@@ -113,9 +115,14 @@ namespace ngraph
}
else
{
parameters.push_back(get_ng_parameter());
return parameters.back();
const auto pt = weights.find(get_name());
if (pt != std::end(weights))
{
return get_ng_constant(pt->second);
}
}
parameters.push_back(get_ng_parameter());
return parameters.back();
}
protected:
......@@ -124,6 +131,11 @@ namespace ngraph
return std::make_shared<op::Parameter>(get_element_type(), get_shape());
}
std::shared_ptr<op::Constant> get_ng_constant(const Weight& weight) const
{
return std::make_shared<op::Constant>(weight.type(), weight.shape(), weight.data());
}
std::shared_ptr<op::Constant> get_ng_constant(const Tensor& tensor) const
{
switch (m_value_info_proto->type().tensor_type().elem_type())
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"
namespace ngraph
{
namespace onnx_import
{
/// \brief Weight for an input
class Weight
{
public:
Weight(const Weight&) = default;
Weight& operator=(const Weight&) = delete;
Weight() = delete;
Weight(Weight&&) = default;
Weight& operator=(Weight&&) = delete;
Weight(const element::Type& type, const Shape& shape, std::vector<char> data)
: m_shape{shape}
, m_type{type}
, m_data{std::move(data)}
{
for (const auto& value : m_shape)
{
m_size *= value;
}
}
const Shape& shape() const { return m_shape; }
std::size_t size() const { return m_size; }
const element::Type& type() const { return m_type; }
std::shared_ptr<runtime::Tensor> to_tensor(runtime::Backend& backend)
{
return backend.create_tensor(
m_type, m_shape, reinterpret_cast<void*>(m_data.data()));
}
const void* data() const { return reinterpret_cast<const void*>(m_data.data()); }
private:
Shape m_shape{};
const element::Type& m_type;
std::size_t m_size{1};
std::vector<char> m_data{};
};
using Weights = std::unordered_map<std::string, Weight>;
}
}
......@@ -52,7 +52,8 @@ namespace ngraph
} // namespace error
} // namespace detail
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream& sin)
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream& sin,
const Weights& weights)
{
onnx::ModelProto model_proto;
if (!model_proto.ParseFromIstream(&sin))
......@@ -61,7 +62,7 @@ namespace ngraph
}
std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto};
Graph graph{model_proto.graph(), model};
Graph graph{model_proto.graph(), model, weights};
for (const auto& output : graph.get_outputs())
{
output_functions.emplace_back(std::make_shared<Function>(
......@@ -70,24 +71,26 @@ namespace ngraph
return output_functions;
}
std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string& path)
std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string& path,
const Weights& weights)
{
std::ifstream ifs{path, std::ios::in | std::ios::binary};
if (!ifs.is_open())
{
throw detail::error::file_open{path};
}
return load_onnx_model(ifs);
return load_onnx_model(ifs, weights);
}
std::shared_ptr<Function> import_onnx_function(std::istream& sin)
std::shared_ptr<Function> import_onnx_function(std::istream& sin, const Weights& weights)
{
return load_onnx_model(sin).front();
return load_onnx_model(sin, weights).front();
}
std::shared_ptr<Function> import_onnx_function(const std::string& path)
std::shared_ptr<Function> import_onnx_function(const std::string& path,
const Weights& weights)
{
return load_onnx_model(path).front();
return load_onnx_model(path, weights).front();
}
void register_operator(const std::string& name,
......
......@@ -22,28 +22,73 @@
#include "ngraph/function.hpp"
#include "core/operator_set.hpp"
#include "core/weight.hpp"
namespace ngraph
{
namespace onnx_import
{
// Registers ONNX custom operator
/// \brief Registers ONNX custom operator
/// Performs the registration of external ONNX operator. This means the code
/// of the operator is not part of ONNX importer. The operator shall be registered
/// before calling `load_onnx_model()` or `import_onnx_function()` functions.
/// \param name name of the operator,
/// \param version version of the operator (opset),
/// \param domain domain the operator belongs to,
/// \param fn function providing the implementation of the operator.
void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn);
// Convert on ONNX model to a vector of nGraph Functions (input stream)
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream&);
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The serialized
/// ONNX model is read from input stream.
/// \param sin input stream (e.g. file stream, memory stream, etc),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream& sin,
const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
/// is read from ONNX file.
/// \param filename file name (relative or absolute path name),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string& filename,
const Weights& weights = {});
// Convert an ONNX model to a vector of nGraph Functions
std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string&);
/// \brief Convert an ONNX model to nGraph function
/// The function translated serialized ONNX model to nGraph function. The serialized
/// ONNX model is read from input stream.
/// \param sin input stream (e.g. file stream, memory stream, etc),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph.
std::shared_ptr<Function> import_onnx_function(std::istream& sin,
const Weights& weights = {});
// Convert the first output of an ONNX model to an nGraph Function (input stream)
std::shared_ptr<Function> import_onnx_function(std::istream&);
// Convert the first output of an ONNX model to an nGraph Function
std::shared_ptr<Function> import_onnx_function(const std::string&);
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
/// is read from ONNX file.
/// \param filename file name (relative or absolute path name),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph.
std::shared_ptr<Function> import_onnx_function(const std::string& filename,
const Weights& weights = {});
} // namespace onnx_import
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment