Unverified Commit b6078c37 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by GitHub

Improve ONNX importer API documentation (#4386)

* Split api to onnx_utils, headers refactor

* Improved ONNX importer API doc

* styles applied

* Code review remarks introduced

* remarks introduced, style applied

* Code review remarks introduced

* namespaces description added

* Code review remarks introduced

* styles applied

* Fixed onnx_utils doc and margin alignment

* code review remarks introduced
Co-authored-by: 's avatarMichał Karzyński <postrational@users.noreply.github.com>
parent b95c25f6
......@@ -21,7 +21,9 @@ add_library(onnx_import_interface OBJECT
core/node.hpp
core/operator_set.hpp
onnx.cpp
onnx.hpp)
onnx.hpp
onnx_utils.hpp
onnx_utils.cpp)
add_library(onnx_import STATIC
core/attribute.cpp
......
......@@ -27,7 +27,10 @@ namespace ngraph
{
namespace onnx_import
{
/// \brief Function which transforms single ONNX operator to nGraph sub-graph.
using Operator = std::function<NodeVector(const Node&)>;
/// \brief Map which contains ONNX operators accessible by std::string value as a key.
using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>;
} // namespace onnx_import
......
......@@ -52,20 +52,20 @@ namespace ngraph
} // namespace error
} // namespace detail
std::shared_ptr<Function> import_onnx_model(std::istream& sin)
std::shared_ptr<Function> import_onnx_model(std::istream& stream)
{
onnx::ModelProto model_proto;
// Try parsing input as a binary protobuf message
if (!model_proto.ParseFromIstream(&sin))
if (!model_proto.ParseFromIstream(&stream))
{
// Rewind to the beginning and clear stream state.
sin.clear();
sin.seekg(0);
google::protobuf::io::IstreamInputStream iistream(&sin);
stream.clear();
stream.seekg(0);
google::protobuf::io::IstreamInputStream iistream(&stream);
// Try parsing input as a prototxt message
if (!google::protobuf::TextFormat::Parse(&iistream, &model_proto))
{
throw detail::error::stream_parse{sin};
throw detail::error::stream_parse{stream};
}
}
......@@ -80,24 +80,16 @@ namespace ngraph
return function;
}
std::shared_ptr<Function> import_onnx_model(const std::string& path)
std::shared_ptr<Function> import_onnx_model(const std::string& file_path)
{
std::ifstream ifs{path, std::ios::in | std::ios::binary};
std::ifstream ifs{file_path, std::ios::in | std::ios::binary};
if (!ifs.is_open())
{
throw detail::error::file_open{path};
throw detail::error::file_open{file_path};
}
return import_onnx_model(ifs);
}
void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn)
{
OperatorsBridge::register_operator(name, version, domain, std::move(fn));
}
std::set<std::string> get_supported_operators(std::int64_t version,
const std::string& domain)
{
......
......@@ -18,69 +18,68 @@
#include <cstdint>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include "core/operator_set.hpp"
#include "ngraph/function.hpp"
#include "ngraph/ngraph_visibility.hpp"
/// \brief Top level nGraph namespace.
namespace ngraph
{
/// \brief ONNX importer features namespace.
/// Functions in this namespace make it possible to use ONNX models.
namespace onnx_import
{
/// \brief Registers ONNX custom operator
/// Performs the registration of external ONNX operator. This means the code
/// of the operator is not part of ONNX importer. The operator shall be registered
/// before calling `load_onnx_model()` or `import_onnx_function()` functions.
/// \param name name of the operator,
/// \param version version of the operator (opset),
/// \param domain domain the operator belongs to,
/// \param fn function providing the implementation of the operator.
NGRAPH_API
void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn);
/// \brief Return the set of names of supported operators.
/// \brief Returns a set of names of supported operators
/// for the given opset version and domain.
///
/// \param[in] version The requested version of ONNX operators set.
/// \param[in] domain The requested domain the operators where registered for.
/// \param[in] version An opset version to get the supported operators for.
/// \param[in] domain A domain to get the supported operators for.
///
/// \return The set containing names of supported operators.
///
NGRAPH_API
std::set<std::string> get_supported_operators(std::int64_t version,
const std::string& domain);
/// \brief Determines whether ONNX operator is supported.
///
/// \param[in] op_name The ONNX operator name.
/// \param[in] version The ONNX operator set version.
/// \param[in] domain The domain the ONNX operator is registered to.
///
/// \return True if operator is supported, False otherwise.
/// \param[in] op_name The ONNX operator name.
/// \param[in] version The ONNX operator set version.
/// \param[in] domain The domain the ONNX operator is registered to.
/// If not set, the default domain "ai.onnx" is used.
///
/// \return true if operator is supported, false otherwise.
NGRAPH_API
bool is_operator_supported(const std::string& op_name,
std::int64_t version,
const std::string& domain = "ai.onnx");
/// \brief Convert an ONNX model to nGraph function
/// The function translated serialized ONNX model to nGraph function. The serialized
/// ONNX model is read from input stream.
/// \param sin input stream (e.g. file stream, memory stream, etc)
/// \return The function returns a nGraph function representing single output from graph.
/// \brief Imports and converts an serialized ONNX model from the input stream
/// to an nGraph Function representation.
///
/// \note If stream parsing fails or the ONNX model contains unsupported ops,
/// the function throws an ngraph_error exception.
///
/// \param[in] stream The input stream (e.g. file stream, memory stream, etc).
///
/// \return An nGraph function that represents a single output from the created graph.
NGRAPH_API
std::shared_ptr<Function> import_onnx_model(std::istream& sin);
std::shared_ptr<Function> import_onnx_model(std::istream& stream);
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
/// is read from ONNX file.
/// \param filename file name (relative or absolute path name)
/// \return The function returns a nGraph function representing single output from graph.
/// \brief Imports and converts an ONNX model from the input file
/// to an nGraph Function representation.
///
/// \note If file parsing fails or the ONNX model contains unsupported ops,
/// the function throws an ngraph_error exception.
///
/// \param[in] file_path The path to a file containing the ONNX model
/// (relative or absolute).
///
/// \return An nGraph function that represents a single output from the created graph.
NGRAPH_API
std::shared_ptr<Function> import_onnx_model(const std::string& filename);
std::shared_ptr<Function> import_onnx_model(const std::string& file_path);
} // namespace onnx_import
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "onnx_utils.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
namespace onnx_import
{
void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn)
{
OperatorsBridge::register_operator(name, version, domain, std::move(fn));
}
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstdint>
#include <string>
#include "core/operator_set.hpp"
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
{
namespace onnx_import
{
/// \brief Registers ONNX custom operator.
/// The function performs the registration of external ONNX operator
/// which is not part of ONNX importer.
///
/// \note The operator must be registered before calling
/// "import_onnx_model" functions.
///
/// \param name The ONNX operator name.
/// \param version The ONNX operator set version.
/// \param domain The domain the ONNX operator is registered to.
/// \param fn The function providing the implementation of the operator
/// which transforms the single ONNX operator to an nGraph sub-graph.
NGRAPH_API
void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn);
} // namespace onnx_import
} // namespace ngraph
......@@ -36,6 +36,7 @@
#include "gtest/gtest.h"
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/frontend/onnx_import/onnx_utils.hpp"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment