Unverified commit afd8e51a, authored by Robert Kimball, committed by GitHub

Merge branch 'master' into bob/backend_api3

parents 94d93423 917efb94
......@@ -84,16 +84,13 @@ set(NGRAPH_FORWARD_CMAKE_ARGS
if (NOT MSVS)
if(NOT CMAKE_BUILD_TYPE)
set(NGRAPH_FORWARD_CMAKE_ARGS
${NGRAPH_FORWARD_CMAKE_ARGS}
-DCMAKE_BUILD_TYPE=Release
)
else()
set(NGRAPH_FORWARD_CMAKE_ARGS
${NGRAPH_FORWARD_CMAKE_ARGS}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
)
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type" FORCE)
endif()
set(NGRAPH_FORWARD_CMAKE_ARGS
${NGRAPH_FORWARD_CMAKE_ARGS}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
)
endif()
message(STATUS "NGRAPH_FORWARD_CMAKE_ARGS ${NGRAPH_FORWARD_CMAKE_ARGS}")
......@@ -339,7 +336,7 @@ else()
include(cmake/external_llvm.cmake)
endif()
if (WIN32)
if (WIN32 OR APPLE)
include(cmake/external_tbb_prebuilt.cmake)
else()
include(cmake/external_tbb.cmake)
......
......@@ -49,7 +49,7 @@ endif()
# This section sets up MKL as an external project to be used later by MKLDNN
set(MKLURLROOT "https://github.com/intel/mkl-dnn/releases/download/v0.17/")
set(MKLURLROOT "https://github.com/intel/mkl-dnn/releases/download/v0.17.2/")
set(MKLVERSION "2019.0.1.20180928")
if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
set(MKLPACKAGE "mklml_lnx_${MKLVERSION}.tgz")
......@@ -90,7 +90,7 @@ set(MKL_LIBS ${TMP_PATHS})
target_link_libraries(libmkl INTERFACE ${MKL_LIBS})
set(MKLDNN_GIT_REPO_URL https://github.com/intel/mkl-dnn)
set(MKLDNN_GIT_TAG "830a100")
set(MKLDNN_GIT_TAG "b9ce57a")
if(NGRAPH_LIB_VERSIONING_ENABLE)
set(MKLDNN_PATCH_FILE mkldnn.patch)
else()
......
......@@ -16,10 +16,13 @@
include(ExternalProject)
set(ARCHIVE_FILE_BASE tbb2019_20181203oss)
if (WIN32)
set(ARCHIVE_FILE_BASE tbb2019_20181203oss)
set(TBB_FILE https://github.com/01org/tbb/releases/download/2019_U3/${ARCHIVE_FILE_BASE}_win.zip)
set(TBB_SHA1_HASH 1989458a49e780d76248edac13b963f80c9a460c)
elseif(APPLE)
set(TBB_FILE https://github.com/01org/tbb/releases/download/2019_U3/${ARCHIVE_FILE_BASE}_mac.tgz)
set(TBB_SHA1_HASH 36926fb46add578b88a5c7e19652b94bb612e4be)
endif()
ExternalProject_Add(
......@@ -37,11 +40,26 @@ ExternalProject_Add(
ExternalProject_Get_Property(ext_tbb SOURCE_DIR)
set(SOURCE_DIR ${SOURCE_DIR}/${ARCHIVE_FILE_BASE})
set(TBB_LINK_LIBS
${SOURCE_DIR}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}clangTooling${CMAKE_SHARED_LIBRARY_SUFFIX}
)
if (WIN32)
set(TBB_LINK_LIBS
${SOURCE_DIR}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}clangTooling${CMAKE_SHARED_LIBRARY_SUFFIX}
)
elseif(APPLE)
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
set(TBB_LIB_NAME tbb_debug)
else()
set(TBB_LIB_NAME tbb)
endif()
set(TBB_LINK_LIBS
${NGRAPH_BUILD_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}${TBB_LIB_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX}
)
add_custom_command(TARGET ext_tbb POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${SOURCE_DIR}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}${TBB_LIB_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX} ${NGRAPH_BUILD_DIR}
COMMENT "Move tbb libraries to ngraph build directory"
)
endif()
add_library(libtbb INTERFACE)
add_dependencies(libtbb ext_tbb)
......
......@@ -16,7 +16,7 @@ framework-based complexity and [import it] to test or run on targeted and
efficient backends with our user-friendly Python-based API.
nGraph is also integrated as an execution provider for [ONNX Runtime],
which is the first publicably available inference engine for ONNX.
which is the first publicly available inference engine for ONNX.
The table below summarizes our current progress on supported frameworks.
If you are an architect of a framework wishing to take advantage of speed
......
......@@ -145,6 +145,8 @@ add_library(onnx_import STATIC
op/tanh.hpp
op/thresholded_relu.cpp
op/thresholded_relu.hpp
op/topk.cpp
op/topk.hpp
op/transpose.cpp
op/transpose.hpp
op/unsqueeze.cpp
......
......@@ -22,7 +22,7 @@ namespace ngraph
{
namespace onnx_import
{
std::vector<Graph> Attribute::get_graph_array(const Model& model) const
std::vector<Graph> Attribute::get_graph_array(Model& model) const
{
std::vector<Graph> result;
for (const auto& graph : m_attribute_proto->graphs())
......@@ -32,7 +32,7 @@ namespace ngraph
return result;
}
Graph Attribute::get_graph(const Model& model) const
Graph Attribute::get_graph(Model& model) const
{
return Graph{m_attribute_proto->g(), model};
}
......
......@@ -278,7 +278,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); }
Graph get_graph(const Model&) const;
Graph get_graph(Model&) const;
std::vector<Tensor> get_tensor_array() const
{
......@@ -303,7 +303,7 @@ namespace ngraph
std::end(m_attribute_proto->strings())};
}
std::vector<Graph> get_graph_array(const Model&) const;
std::vector<Graph> get_graph_array(Model&) const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const
{
......
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <set>
#include "graph.hpp"
......@@ -25,26 +26,40 @@ namespace ngraph
{
namespace detail
{
std::string to_string(const std::set<std::string>& set)
static std::string to_string(
const std::map<std::string, std::reference_wrapper<const onnx::NodeProto>>& map)
{
std::string result;
for (auto it = std::begin(set); it != std::end(set); ++it)
for (auto it = std::begin(map); it != std::end(map); ++it)
{
result += (it != std::begin(set) ? ", " : "") + *it;
result += (it != std::begin(map) ? ", " : "") + it->first;
}
return result;
}
inline std::string to_string(const onnx::NodeProto& node_proto)
static std::string get_node_domain(const onnx::NodeProto& node_proto)
{
return (node_proto.domain().empty() ? "" : node_proto.domain() + ".") +
node_proto.op_type();
return (node_proto.domain().empty() ? "" : node_proto.domain());
}
}
Graph::Graph(const onnx::GraphProto& graph_proto,
const Model& model,
const Weights& weights)
/// \brief Gets the unique identifier of the operator represented by the provided node.
///
/// \param[in] node_proto The node protobuf representation object.
///
/// \note The operator is uniquely identified by the tuple (domain, op_type,
/// since_version). The first two elements are stored in the NodeProto object,
/// so only those are used here.
///
/// \return The unique identifier.
///
static std::string get_op_domain_and_name(const onnx::NodeProto& node_proto)
{
std::string domain = get_node_domain(node_proto);
return (domain.empty() ? "" : domain + ".") + node_proto.op_type();
}
} // namespace detail
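// For example (using the custom domain exercised by the tests later in this
// diff): a node with domain "custom.op" and op_type "CustomAdd" maps to the
// identifier "custom.op.CustomAdd", while a node from the default ONNX domain
// (empty string) maps to just its op_type, e.g. "Add".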
Graph::Graph(const onnx::GraphProto& graph_proto, Model& model, const Weights& weights)
: m_graph_proto{&graph_proto}
, m_model{&model}
{
......@@ -70,17 +85,34 @@ namespace ngraph
}
// Verify that ONNX graph contains only nodes of available operator types
std::set<std::string> unknown_operator_types;
std::map<std::string, std::reference_wrapper<const onnx::NodeProto>> unknown_operators;
for (const auto& node_proto : m_graph_proto->node())
{
if (!m_model->is_operator_available(node_proto))
{
unknown_operator_types.emplace(detail::to_string(node_proto));
unknown_operators.emplace(detail::get_op_domain_and_name(node_proto),
node_proto);
// Try adding missing domain
m_model->enable_opset_domain(detail::get_node_domain(node_proto));
}
}
// Re-verify whether any operators are still unavailable.
auto it = std::begin(unknown_operators);
while (it != std::end(unknown_operators))
{
if (m_model->is_operator_available(it->second))
{
it = unknown_operators.erase(it);
}
else
{
it++;
}
}
NGRAPH_ASSERT(unknown_operator_types.empty())
<< "unknown operations: " << detail::to_string(unknown_operator_types);
NGRAPH_ASSERT(unknown_operators.empty()) << "unknown operations: "
<< detail::to_string(unknown_operators);
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto->node())
......
......@@ -33,7 +33,7 @@ namespace ngraph
class Graph
{
public:
Graph(const onnx::GraphProto& proto, const Model& model, const Weights& weights = {});
Graph(const onnx::GraphProto& proto, Model& model, const Weights& weights = {});
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
......@@ -59,7 +59,7 @@ namespace ngraph
ParameterVector m_parameters;
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
std::map<std::string, Tensor> m_initializers;
const Model* m_model;
Model* m_model;
};
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
......
......@@ -17,6 +17,7 @@
#include <onnx-ml.pb.h>
#include "model.hpp"
#include "ngraph/log.hpp"
#include "ops_bridge.hpp"
namespace ngraph
......@@ -33,14 +34,14 @@ namespace ngraph
{
m_opset.emplace(id.domain(),
OperatorsBridge::get_operator_set(
id.version(), (id.domain() == "ai.onnx" ? "" : id.domain())));
(id.domain() == "ai.onnx" ? "" : id.domain()), id.version()));
}
// onnx.proto(.3): the empty string ("") for domain or absence of opset_import field
// implies the operator set that is defined as part of the ONNX specification.
const auto dm = m_opset.find("");
if (dm == std::end(m_opset))
{
m_opset.emplace("", OperatorsBridge::get_operator_set(ONNX_OPSET_VERSION, ""));
m_opset.emplace("", OperatorsBridge::get_operator_set("", ONNX_OPSET_VERSION));
}
}
......@@ -71,6 +72,26 @@ namespace ngraph
return (op != std::end(dm->second));
}
void Model::enable_opset_domain(const std::string& domain)
{
// There is no need to 'update' an already enabled domain.
// Since this function may be called (possibly multiple times)
// only during model import, the registered domain opset won't
// differ between subsequent calls.
if (m_opset.find(domain) == std::end(m_opset))
{
OperatorSet opset{OperatorsBridge::get_operator_set(domain)};
if (opset.empty())
{
NGRAPH_WARN << "Couldn't enable domain: " << domain
<< " since it hasn't any registered operators.";
return;
}
m_opset.emplace(domain, opset);
}
}
} // namespace onnx_import
} // namespace ngraph
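To illustrate the intended flow, here is a hedged sketch mirroring the model_missing_op_domain unit test added later in this diff (the operator name, domain, and file name come from that test; the path is abbreviated for illustration):

// Register a handler for an operator living in a non-default domain.
onnx_import::register_operator(
    "CustomAdd", 1, "custom.op", [](const onnx_import::Node& node) -> NodeVector {
        NodeVector ng_inputs{node.get_ng_inputs()};
        return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
    });

// A model that uses custom.op.CustomAdd but declares no opset_import for that
// domain now imports successfully: Graph's constructor calls
// Model::enable_opset_domain("custom.op") when it first sees the unknown node.
auto function = onnx_import::import_onnx_model("onnx/missing_op_domain.onnx");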
......@@ -61,6 +61,15 @@ namespace ngraph
/// \return `true` if the operator is available, otherwise it returns `false`.
bool is_operator_available(const onnx::NodeProto& node_proto) const;
/// \brief Enable operators from the provided domain for use by this model.
///
/// \note This function makes all operators currently registered in the provided
/// domain visible for use in this model.
///
/// \param[in] domain The domain name.
///
void enable_opset_domain(const std::string& domain);
private:
const onnx::ModelProto* m_model_proto;
std::unordered_map<std::string, OperatorSet> m_opset;
......
......@@ -181,6 +181,34 @@ namespace ngraph
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int8_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int8_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT8)
{
return detail::__get_data<int8_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int16_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int16_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT16)
{
return detail::__get_data<int16_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int32_t> get_data(const onnx::TensorProto& tensor)
{
......@@ -209,6 +237,48 @@ namespace ngraph
return detail::__get_data<int64_t>(tensor.int64_data());
}
template <>
inline std::vector<uint8_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint8_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT8)
{
return detail::__get_data<uint8_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint16_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint16_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT16)
{
return detail::__get_data<uint16_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint32_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint32_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT32)
{
return detail::__get_data<uint32_t>(tensor.uint64_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint64_t> get_data(const onnx::TensorProto& tensor)
{
......
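These specializations all follow one pattern dictated by the ONNX TensorProto wire format: narrow integer types (int8/int16/uint8/uint16) are packed into the int32_data field, while uint32 rides in uint64_data. The detail helpers referenced above are not shown in this hunk; a minimal sketch of what they are assumed to do:

namespace detail
{
    // Reinterpret the raw byte blob as a vector of T (assumes the blob holds
    // tightly packed little-endian values of type T).
    template <typename T>
    inline std::vector<T> __get_raw_data(const std::string& raw_data)
    {
        auto it = reinterpret_cast<const T*>(raw_data.data());
        return {it, it + (raw_data.size() / sizeof(T))};
    }

    // Element-wise convert a repeated protobuf field to a vector of T,
    // e.g. narrowing int32_data down to int8_t.
    template <typename T, typename Container>
    inline std::vector<T> __get_data(const Container& container)
    {
        return {std::begin(container), std::end(container)};
    }
}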
......@@ -90,7 +90,8 @@ namespace ngraph
std::set<std::string> get_supported_operators(std::int64_t version,
const std::string& domain)
{
OperatorSet op_set{OperatorsBridge::get_operator_set(version, domain)};
OperatorSet op_set{
OperatorsBridge::get_operator_set(domain == "ai.onnx" ? "" : domain, version)};
std::set<std::string> op_list{};
for (const auto& op : op_set)
{
......
......@@ -65,6 +65,20 @@ namespace ngraph
return __make_ng_constant<double>(element::f64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int8>(const Tensor& tensor)
{
return __make_ng_constant<int8_t>(element::i8, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int16>(const Tensor& tensor)
{
return __make_ng_constant<int16_t>(element::i16, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
......@@ -79,6 +93,20 @@ namespace ngraph
return __make_ng_constant<int64_t>(element::i64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint8>(const Tensor& tensor)
{
return __make_ng_constant<uint8_t>(element::u8, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint16>(const Tensor& tensor)
{
return __make_ng_constant<uint16_t>(element::u16, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
......@@ -103,8 +131,12 @@ namespace ngraph
MAKE_NG_CONSTANT(Tensor::Type::float16);
MAKE_NG_CONSTANT(Tensor::Type::float32);
MAKE_NG_CONSTANT(Tensor::Type::float64);
MAKE_NG_CONSTANT(Tensor::Type::int8);
MAKE_NG_CONSTANT(Tensor::Type::int16);
MAKE_NG_CONSTANT(Tensor::Type::int32);
MAKE_NG_CONSTANT(Tensor::Type::int64);
MAKE_NG_CONSTANT(Tensor::Type::uint8);
MAKE_NG_CONSTANT(Tensor::Type::uint16);
MAKE_NG_CONSTANT(Tensor::Type::uint32);
MAKE_NG_CONSTANT(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor};
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include <memory>
#include <vector>
#include "exceptions.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/type/element_type.hpp"
#include "topk.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector topk(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
std::int64_t k{node.get_attribute_value<std::int64_t>("k")};
auto num_dimensions = data->get_shape().size();
if (axis < 0)
{
axis += num_dimensions;
}
ASSERT_VALID_ARGUMENT(node, axis < num_dimensions)
<< "`axis` parameter is out of range: " << axis;
std::shared_ptr<ngraph::Node> top_k =
std::make_shared<ngraph::op::TopK>(data, axis, element::i64, k);
std::shared_ptr<ngraph::Node> indices =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 0);
std::shared_ptr<ngraph::Node> values =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 1);
return {values, indices};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
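Concretely (mirroring the model_top_k unit test added later in this diff): for a 3x4 input [[0,1,2,3],[4,5,6,7],[8,9,10,11]] with k=3 and the default axis of -1 (normalized to 1), the importer returns values [[3,2,1],[7,6,5],[11,10,9]] and indices [[3,2,1],[3,2,1],[3,2,1]], in the ONNX output order: values first, then indices.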
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Performs ONNX TopK operation.
///
/// \param node The ONNX node object representing this operation.
/// \return The vector of nGraph nodes producing the outputs of the ONNX TopK
/// operation (both values and indices).
NodeVector topk(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -94,6 +94,7 @@
#include "op/tan.hpp"
#include "op/tanh.hpp"
#include "op/thresholded_relu.hpp"
#include "op/topk.hpp"
#include "op/transpose.hpp"
#include "op/unsqueeze.hpp"
#include "op/xor.hpp"
......@@ -109,6 +110,11 @@ namespace ngraph
find(std::int64_t version, const std::map<std::int64_t, Operator>& map)
{
std::map<std::int64_t, Operator>::const_iterator it{};
// Get the latest version.
if (version == -1)
{
return map.empty() ? std::end(map) : --std::end(map);
}
while (version > 0)
{
it = map.find(version--);
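// The loop above (truncated by this hunk) walks downward from the requested
// version, so the resolved handler is the nearest registered version not
// greater than the request, and version == -1 selects the newest one.
// Illustration with made-up registrations (hedged example, not from this change):
//   std::map<std::int64_t, Operator> versions = {{1, add_v1}, {7, add_v7}};
//   find(9,  versions) -> iterator to {7, add_v7}   (nearest version <= 9)
//   find(7,  versions) -> iterator to {7, add_v7}   (exact match)
//   find(5,  versions) -> iterator to {1, add_v1}
//   find(-1, versions) -> iterator to {7, add_v7}   (latest registered)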
......@@ -126,23 +132,29 @@ namespace ngraph
const std::string& domain,
Operator fn)
{
m_map[domain][name].emplace(version, std::move(fn));
auto it = m_map[domain][name].find(version);
if (it == std::end(m_map[domain][name]))
{
m_map[domain][name].emplace(version, std::move(fn));
}
else
{
// emplace alone would silently keep the old handler; overwrite it and warn.
NGRAPH_WARN << "Overwriting existing operator: "
<< domain + "." + name + ":" + std::to_string(version);
it->second = std::move(fn);
}
}
OperatorSet OperatorsBridge::_get_operator_set(std::int64_t version,
const std::string& domain)
OperatorSet OperatorsBridge::_get_operator_set(const std::string& domain,
std::int64_t version)
{
OperatorSet result;
auto dm = m_map.find(domain);
if (dm == std::end(m_map))
{
throw error::UnknownDomain{domain};
}
if (version > OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION)
if (domain == "" && version > OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION)
{
NGRAPH_WARN << "Currently operator set version: " << version << " is unsupported."
<< " Falling back to: "
<< OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION;
NGRAPH_WARN << "Currently ONNX operator set version: " << version
<< " is unsupported. Falling back to: "
<< OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION;
}
for (const auto& op : dm->second)
{
......@@ -277,6 +289,7 @@ namespace ngraph
REGISTER_OPERATOR("Tan", 1, tan);
REGISTER_OPERATOR("Tanh", 1, tanh);
REGISTER_OPERATOR("ThresholdedRelu", 1, thresholded_relu);
REGISTER_OPERATOR("TopK", 1, topk);
REGISTER_OPERATOR("Transpose", 1, transpose);
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Xor", 1, logical_xor);
......
......@@ -62,16 +62,17 @@ namespace ngraph
class OperatorsBridge
{
public:
static constexpr const int LATEST_SUPPORTED_OPSET_VERSION = ONNX_OPSET_VERSION;
static constexpr const int LATEST_SUPPORTED_ONNX_OPSET_VERSION = ONNX_OPSET_VERSION;
OperatorsBridge(const OperatorsBridge&) = delete;
OperatorsBridge& operator=(const OperatorsBridge&) = delete;
OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static OperatorSet get_operator_set(std::int64_t version, const std::string& domain)
static OperatorSet get_operator_set(const std::string& domain,
std::int64_t version = -1)
{
return instance()._get_operator_set(version, domain);
return instance()._get_operator_set(domain, version);
}
static void register_operator(const std::string& name,
......@@ -90,6 +91,20 @@ namespace ngraph
}
private:
// Registered operators structure
// {
// domain_1: {
// op_type_1: {
// version_1: {func_handle},
// version_2: {func_handle},
// ...
// },
// op_type_2: { ... }
// ...
// },
// domain_2: { ... },
// ...
// }
std::unordered_map<std::string,
std::unordered_map<std::string, std::map<std::int64_t, Operator>>>
m_map;
......@@ -106,7 +121,8 @@ namespace ngraph
std::int64_t version,
const std::string& domain,
Operator fn);
OperatorSet _get_operator_set(std::int64_t version, const std::string& domain);
OperatorSet _get_operator_set(const std::string& domain, std::int64_t version);
bool _is_operator_registered(const std::string& name,
std::int64_t version,
const std::string& domain);
......
......@@ -191,3 +191,26 @@ void Function::replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> rep
{
ngraph::replace_node(old, repl);
}
size_t Function::get_graph_size() const
{
size_t total_size = 0;
for (auto node : get_ops())
{
total_size += sizeof(*node);
if (node->description() == "Constant")
{
const Shape& shape = node->get_outputs()[0].get_shape();
size_t const_size = node->get_outputs()[0].get_element_type().size();
if (shape.size() == 0)
{
total_size += const_size;
}
else
{
total_size += (const_size * shape_size(node->get_outputs()[0].get_shape()));
}
}
}
return total_size;
}
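// Quick arithmetic check (illustrative numbers): an f32 Constant (4-byte
// elements) of shape {2, 3} contributes sizeof(*node) + 4 * 6 bytes to the
// total, while a rank-0 (scalar) f32 Constant contributes sizeof(*node) + 4.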
......@@ -85,6 +85,11 @@ namespace ngraph
void validate_nodes_and_infer_types();
/// \brief Returns the sum of the size of all nodes in the graph plus the size of
/// all constant data. This has little value beyond comparing the relative size of
/// graphs and should not be considered the actual memory consumption of a graph.
size_t get_graph_size() const;
protected:
ResultVector m_results;
ParameterVector m_parameters;
......
......@@ -491,27 +491,3 @@ void Node::validate_and_infer_elementwise_logical()
set_output_type(0, element::boolean, args_pshape);
}
bool Node::validate_punt_if_dynamic()
{
bool any_dynamic = false;
for (auto& input : m_inputs)
{
any_dynamic |= input.get_partial_shape().is_dynamic();
any_dynamic |= input.get_element_type().is_dynamic();
}
if (any_dynamic)
{
for (size_t i = 0; i < get_output_size(); i++)
{
set_output_type(i, element::dynamic, PartialShape::dynamic());
}
return true;
}
else
{
return false;
}
}
......@@ -99,17 +99,6 @@ namespace ngraph
void validate_and_infer_elementwise_arithmetic();
void validate_and_infer_elementwise_logical();
// Temporary hack while partial shape propagation is being implemented. If any input has
// dynamic shape or dynamic element type, sets all outputs to have a shape of dynamic
// rank and dynamic element type. Ops where we haven't yet implemented partial shape
// propagation can add this boilerplate at the top of their validate_and_infer_types():
//
// if (validate_punt_if_dynamic())
// {
// return;
// }
bool validate_punt_if_dynamic();
Node(const std::string& node_type, const NodeVector& arguments, size_t output_size = 1);
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
......
......@@ -45,6 +45,11 @@
//
// It's that easy. You can use this for fun and profit.
#ifndef NGRAPH_OP
#error "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(Abs, ngraph::op)
NGRAPH_OP(Acos, ngraph::op)
NGRAPH_OP(Add, ngraph::op)
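// The guard above enforces this file's X-macro contract: every includer must
// define NGRAPH_OP first. A hedged sketch of a typical consumer (hypothetical,
// not part of this change):
//
//     #define NGRAPH_OP(NAME, NAMESPACE) #NAME,
//     static const char* const s_op_names[] = {
//     #include "ngraph/op/op_tbl.hpp"
//     };
//     #undef NGRAPH_OP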
......@@ -103,6 +108,13 @@ NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(Power, ngraph::op)
NGRAPH_OP(Product, ngraph::op)
NGRAPH_OP(Quantize, ngraph::op)
NGRAPH_OP(QuantizedAvgPool, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBias, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedConvolution, ngraph::op)
NGRAPH_OP(QuantizedMaxPool, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op)
NGRAPH_OP(ReplaceSlice, ngraph::op)
......
......@@ -529,8 +529,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
// scenarios and marks some reshapes as too "toxic" to sink
// For now, this heuristic works really well.
// Note, get_users(*true*) which means we only care about
// live users of Reshape
if (slice->get_argument(0)->get_users(true).size() == 1)
// live users of Reshape. However, get_users(*true*) causes a
// significant time increase on graphs with many slice ops,
// so for now we remove the "true" check and let the backend
// handle reshape sinking for the slice operation.
if (slice->get_argument(0)->get_users().size() == 1)
{
sink_slice(slice, reorders, reshapes_to_delete);
}
......
......@@ -133,7 +133,7 @@ if (NGRAPH_HALIDE)
)
endif()
if (NGRAPH_TBB_ENABLE AND NOT WIN32)
if (NGRAPH_TBB_ENABLE AND NOT (WIN32 OR APPLE))
include(${TBB_ROOT}/cmake/TBBBuild.cmake)
tbb_build(TBB_ROOT ${TBB_ROOT} MAKE_ARGS tbb_build_dir=${CMAKE_CURRENT_BINARY_DIR}/tbb_build
tbb_build_prefix=tbb CONFIG_DIR TBB_DIR)
......
......@@ -306,15 +306,41 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
NodeVector params = p.first;
NodeVector& op_nodes = p.second;
auto data_node = params.at(Type::DATA);
// we will sort the captured Add(Dot(X, W) + B) as per the slice ordering of X
// this will simplify the replace_node logic
auto compare_slices = [&](const std::shared_ptr<Node> node1,
const std::shared_ptr<Node> node2) {
const auto node1_slice =
std::static_pointer_cast<op::Slice>(op_seg_map[node1].at(Type::DATA));
const auto node2_slice =
std::static_pointer_cast<op::Slice>(op_seg_map[node2].at(Type::DATA));
return (node1_slice->get_lower_bounds() < node2_slice->get_lower_bounds() &&
node1_slice->get_upper_bounds() < node2_slice->get_upper_bounds());
};
std::sort(op_nodes.begin(), op_nodes.end(), compare_slices);
// we fuse all the data slices captured in the pattern to make a bigger GEMM call
auto fuse_data_slices = [&]() {
NodeVector data_slices;
for (auto& op : op_nodes)
{
auto data_node = op_seg_map.at(op).at(Type::DATA);
data_slices.push_back(data_node);
}
return std::make_shared<op::Concat>(data_slices, 0);
};
auto data_node = op_nodes.size() > 1 ? fuse_data_slices() : params.at(Type::DATA);
auto weights_node = params.at(Type::WEIGHTS);
auto bias_node = params.at(Type::BIAS);
auto& data_shape = data_node->get_shape();
const auto& data_shape = data_node->get_shape();
// construct new op nodes
auto data_order = ngraph::get_default_order(data_node->get_shape());
auto data_reshape_node = std::make_shared<op::Reshape>(
data_node, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]});
auto data_reshape_node =
std::make_shared<op::Reshape>(data_node,
AxisVector{0, 1, 2},
Shape{data_shape[0] * data_shape[1], data_shape[2]});
auto old_weights_reshape_node = op_seg_map.at(op_nodes.at(0)).at(Type::WEIGHTS);
auto weights_reshape_node =
......@@ -327,30 +353,16 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
auto add_node = std::make_shared<op::Add>(dot_node, bias_broadcast_node);
const auto& add_shape = add_node->get_shape();
// we will sort the captured Add(Dot(X, W) + B) as per the slice ordering of X
// this will simplify the replace_node logic
auto compare_slices = [&](const std::shared_ptr<Node> node1,
const std::shared_ptr<Node> node2) {
const auto node1_slice =
std::static_pointer_cast<op::Slice>(op_seg_map[node1].at(Type::DATA));
const auto node2_slice =
std::static_pointer_cast<op::Slice>(op_seg_map[node2].at(Type::DATA));
return (node1_slice->get_lower_bounds() < node2_slice->get_lower_bounds() &&
node1_slice->get_upper_bounds() < node2_slice->get_upper_bounds());
};
std::sort(op_nodes.begin(), op_nodes.end(), compare_slices);
size_t num_timesteps = op_nodes.size();
size_t batch_size = add_shape[0] / num_timesteps;
size_t feature_size = add_shape[1];
// create a slice for each user of the dot op matching the original dot op's output
for (size_t i = 0, start_index = 0; i < op_nodes.size(); i++, start_index += batch_size)
{
// calculate the lower and upper bounds for the slice of the new fused node
// ((<x0 | x1..|xt>*W)+b), which will be used to replace the nodes matched in the pattern
const Coordinate lower_bounds{start_index, 0};
const Coordinate upper_bounds{start_index + batch_size, add_shape[1]};
const Coordinate upper_bounds{start_index + batch_size, feature_size};
auto slice_node = std::make_shared<op::Slice>(add_node, lower_bounds, upper_bounds);
......
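To make the slicing arithmetic concrete (illustrative numbers, not from the diff): with 3 fused timesteps and add_shape {6, F}, batch_size = 6 / 3 = 2 and feature_size = F, so the loop emits slices with bounds [0,0]-[2,F], [2,0]-[4,F], and [4,0]-[6,F], one per original Add node in slice order.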
......@@ -58,6 +58,11 @@
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -880,6 +885,41 @@ std::string runtime::gpu::GPU_Emitter::emit_Quantize(EMIT_ARGS)
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedAvgPool(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedConvolution(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedConvolutionBias(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedConvolutionBiasAdd(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedConvolutionBiasSignedAdd(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedConvolutionRelu(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedMaxPool(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_Relu(EMIT_ARGS)
{
return emit_elementwise<ngraph::op::Relu>(compiled_function, function_name, node, args, out);
......
......@@ -1781,6 +1781,13 @@ runtime::intelgpu::IntelGPUExecutable::IntelGPUExecutable(shared_ptr<Function> f
case OP_TYPEID::BroadcastLike:
case OP_TYPEID::Dequantize:
case OP_TYPEID::Quantize:
case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolutionBias:
case OP_TYPEID::QuantizedConvolutionBiasAdd:
case OP_TYPEID::QuantizedConvolutionBiasSignedAdd:
case OP_TYPEID::QuantizedConvolutionRelu:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::GenerateMask:
case OP_TYPEID::ReverseSequence:
......
......@@ -145,8 +145,8 @@ namespace ngraph
class INTBackend;
class INTExecutable;
}
}
}
} // namespace runtime
} // namespace ngraph
class ngraph::runtime::interpreter::INTBackend : public Backend
{
......@@ -1024,6 +1024,17 @@ private:
break;
}
case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolutionBias:
case OP_TYPEID::QuantizedConvolutionBiasAdd:
case OP_TYPEID::QuantizedConvolutionBiasSignedAdd:
case OP_TYPEID::QuantizedConvolutionRelu:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedMaxPool:
{
throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end.");
}
case OP_TYPEID::Relu:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
......@@ -48,6 +48,11 @@
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -957,6 +962,66 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Quantize>(args[0], args[1], args[2], type, axes, round_mode);
break;
}
case OP_TYPEID::QuantizedAvgPool:
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation =
node_js.at("include_padding_in_avg_computation").get<bool>();
node = make_shared<op::QuantizedAvgPool>(args[0],
window_shape,
window_movement_strides,
padding_below,
padding_above,
include_padding_in_avg_computation);
break;
}
case OP_TYPEID::QuantizedConvolutionBias: { break;
}
case OP_TYPEID::QuantizedConvolutionBiasAdd: { break;
}
case OP_TYPEID::QuantizedConvolutionBiasSignedAdd: { break;
}
case OP_TYPEID::QuantizedConvolutionRelu: { break;
}
case OP_TYPEID::QuantizedConvolution:
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = node_js["data_dilation_strides"];
node =
make_shared<op::Convolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides.get<std::vector<size_t>>());
break;
}
case OP_TYPEID::QuantizedMaxPool:
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted.
auto padding_below_maybe = node_js["padding_below"];
auto padding_above_maybe = node_js["padding_above"];
auto padding_below = padding_below_maybe.get<vector<size_t>>();
auto padding_above = padding_above_maybe.get<vector<size_t>>();
node = make_shared<op::QuantizedMaxPool>(
args[0], window_shape, window_movement_strides, padding_below, padding_above);
break;
}
case OP_TYPEID::Relu:
{
node = make_shared<op::Relu>(args[0]);
......@@ -1507,6 +1572,43 @@ static json write(const Node& n, bool binary_constant_data)
node["round_mode"] = tmp->get_round_mode();
break;
}
case OP_TYPEID::QuantizedAvgPool:
{
auto tmp = dynamic_cast<const op::QuantizedAvgPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
break;
}
case OP_TYPEID::QuantizedConvolutionBias: { break;
}
case OP_TYPEID::QuantizedConvolutionBiasAdd: { break;
}
case OP_TYPEID::QuantizedConvolutionBiasSignedAdd: { break;
}
case OP_TYPEID::QuantizedConvolutionRelu: { break;
}
case OP_TYPEID::QuantizedConvolution:
{
auto tmp = dynamic_cast<const op::QuantizedConvolution*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["window_dilation_strides"] = tmp->get_window_dilation_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides();
break;
}
case OP_TYPEID::QuantizedMaxPool:
{
auto tmp = dynamic_cast<const op::QuantizedMaxPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
break;
}
case OP_TYPEID::Relu: { break;
}
case OP_TYPEID::ReluBackprop: { break;
......
......@@ -32,6 +32,7 @@ set(SRC
control_dependencies.cpp
coordinate.cpp
copy.cpp
core.cpp
cpio.cpp
cse.cpp
element_type.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
using namespace ngraph;
using namespace std;
TEST(core, function_size)
{
const string m1 = file_util::path_join(SERIALIZED_ZOO, "mxnet/mnist_mlp_forward.json");
const string m2 = file_util::path_join(SERIALIZED_ZOO, "mxnet/10_bucket_LSTM.json");
auto f1 = deserialize(m1);
auto f2 = deserialize(m2);
auto s1 = f1->get_graph_size();
auto s2 = f2->get_graph_size();
EXPECT_GT(s2, s1);
}
......@@ -3450,3 +3450,26 @@ TEST(cpu_fusion, rnn_input_fusion_inter_vs_cpu)
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, validate_fuse_gru_inputs)
{
const std::string file_name("mxnet/gru_debug.json");
auto cpu_func = make_function_from_file(file_name);
auto int_func = make_function_from_file(file_name);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_func->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_func, args, "INTERPRETER");
auto cpu_results = execute(cpu_func, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
[Binary files: this diff also adds two ONNX protobuf test fixtures whose content does not render as text. onnx/missing_op_domain.onnx, produced by ONNXnGraphImporter, defines a compute_graph with a CustomAdd node from the custom.op domain (inputs A and B, output C). onnx/top_k.onnx defines a backend-test graph test_top_k with a TopK node (attributes k and axis) taking input x and producing values and indices.]
......@@ -1819,3 +1819,47 @@ TEST(onnx_${BACKEND_NAME}, model_space_to_depth_no_blocksize)
file_util::path_join(SERIALIZED_ZOO, "onnx/space_to_depth_no_blocksize.onnx")),
std::runtime_error);
}
TEST(onnx_${BACKEND_NAME}, model_missing_op_domain)
{
onnx_import::register_operator(
"CustomAdd", 1, "custom.op", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
EXPECT_TRUE(onnx_import::is_operator_supported("CustomAdd", 1, "custom.op"));
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/missing_op_domain.onnx"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{0.f, 1.f, 2.f, 3.f});
inputs.emplace_back(std::vector<float>{0.f, 1.f, 2.f, 3.f});
Outputs expected_output{std::vector<float>{0.f, 2.f, 4.f, 6.f}};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, model_top_k)
{
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/top_k.onnx"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
std::vector<float> expected_values_output{3, 2, 1, 7, 6, 5, 11, 10, 9};
std::vector<std::int64_t> expected_indices_output{3, 2, 1, 3, 2, 1, 3, 2, 1};
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> result_tensors =
prepare_and_run(function, inputs, "${BACKEND_NAME}");
std::vector<float> values_output = read_vector<float>(result_tensors.at(0));
std::vector<std::int64_t> indices_output = read_vector<std::int64_t>(result_tensors.at(1));
EXPECT_TRUE(test::all_close_f(expected_values_output, values_output));
EXPECT_TRUE(test::all_close(expected_indices_output, indices_output));
}
......@@ -127,10 +127,11 @@ void init_real_tv(ngraph::runtime::Tensor* tv, std::default_random_engine& engin
void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine);
template <typename T, typename T1 = T>
std::vector<std::vector<T1>> execute(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T>> args,
const std::string& backend_id)
template <typename T>
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>
prepare_and_run(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T>> args,
const std::string& backend_id)
{
auto backend = ngraph::runtime::Backend::create(backend_id);
......@@ -160,6 +161,16 @@ std::vector<std::vector<T1>> execute(const std::shared_ptr<ngraph::Function>& fu
auto handle = backend->compile(function);
backend->call_with_validate(handle, result_tensors, arg_tensors);
return result_tensors;
}
template <typename T, typename T1 = T>
std::vector<std::vector<T1>> execute(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T>> args,
const std::string& backend_id)
{
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> result_tensors =
prepare_and_run(function, args, backend_id);
std::vector<std::vector<T1>> result_vectors;
for (auto rt : result_tensors)
......