Commit d3f83f64 authored by Michał Karzyński's avatar Michał Karzyński Committed by Robert Kimball

[ONNX] Assert all op types supported (#1770)

* [ONNX] Assert all op types supported

* Apply clang-format

* Address code review comments

* Fix #include statements
parent 411f83e2
...@@ -25,6 +25,7 @@ add_library(onnx_import STATIC ...@@ -25,6 +25,7 @@ add_library(onnx_import STATIC
core/attribute.hpp core/attribute.hpp
core/graph.cpp core/graph.cpp
core/graph.hpp core/graph.hpp
core/model.cpp
core/model.hpp core/model.hpp
core/node.cpp core/node.cpp
core/node.hpp core/node.hpp
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <onnx-ml.pb.h> #include <onnx-ml.pb.h>
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "tensor.hpp" #include "tensor.hpp"
......
...@@ -16,10 +16,11 @@ ...@@ -16,10 +16,11 @@
#pragma once #pragma once
#include <onnx-ml.pb.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <onnx-ml.pb.h>
#include "ngraph/op/parameter_vector.hpp" #include "ngraph/op/parameter_vector.hpp"
#include "value_info.hpp" #include "value_info.hpp"
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <ostream>
#include <set>
#include <onnx-ml.pb.h>
#include "assertion.hpp"
#include "model.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
namespace onnx_import
{
Model::Model(const onnx::ModelProto& model_proto)
: m_model_proto{&model_proto}
{
// Verify that the ONNX graph contains only nodes of supported op_type
assert_all_op_types_supported();
}
void Model::assert_all_op_types_supported()
{
std::set<std::string> unsupported_ops;
for (const auto& node_proto : get_graph().node())
{
std::string op_type = node_proto.op_type();
if (!ops_bridge::is_op_type_supported(op_type))
{
unsupported_ops.insert(op_type);
}
}
std::string unsupported_ops_str;
std::size_t index = 0;
for (const auto& op_type : unsupported_ops)
{
unsupported_ops_str += (index++ != 0 ? ", " : "");
unsupported_ops_str += op_type;
}
NGRAPH_ASSERT(unsupported_ops.empty()) << "unknown operations: " << unsupported_ops_str;
}
} // namespace onnx_import
} // namespace ngraph
...@@ -16,9 +16,10 @@ ...@@ -16,9 +16,10 @@
#pragma once #pragma once
#include <onnx-ml.pb.h>
#include <ostream> #include <ostream>
#include <onnx-ml.pb.h>
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
...@@ -27,10 +28,7 @@ namespace ngraph ...@@ -27,10 +28,7 @@ namespace ngraph
{ {
public: public:
Model() = delete; Model() = delete;
explicit Model(const onnx::ModelProto& model_proto) explicit Model(const onnx::ModelProto& model_proto);
: m_model_proto{&model_proto}
{
}
Model(Model&&) noexcept = default; Model(Model&&) noexcept = default;
Model(const Model&) = default; Model(const Model&) = default;
...@@ -48,6 +46,8 @@ namespace ngraph ...@@ -48,6 +46,8 @@ namespace ngraph
private: private:
const onnx::ModelProto* m_model_proto; const onnx::ModelProto* m_model_proto;
void assert_all_op_types_supported();
}; };
inline std::ostream& operator<<(std::ostream& outs, const Model& model) inline std::ostream& operator<<(std::ostream& outs, const Model& model)
......
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
#include <string> #include <string>
#include "ngraph/node_vector.hpp"
#include <onnx-ml.pb.h> #include <onnx-ml.pb.h>
#include "ngraph/node_vector.hpp"
#include "attribute.hpp" #include "attribute.hpp"
#include "tensor.hpp" #include "tensor.hpp"
......
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
#include <vector> #include <vector>
#include <onnx-ml.pb.h>
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include <onnx-ml.pb.h>
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
......
...@@ -16,13 +16,13 @@ ...@@ -16,13 +16,13 @@
#pragma once #pragma once
#include <onnx-ml.pb.h>
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter_vector.hpp" #include "ngraph/op/parameter_vector.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include <onnx-ml.pb.h>
#include "node.hpp" #include "node.hpp"
#include "tensor.hpp" #include "tensor.hpp"
......
...@@ -15,14 +15,14 @@ ...@@ -15,14 +15,14 @@
//***************************************************************************** //*****************************************************************************
#include <memory> #include <memory>
#include <onnx-ml.pb.h> #include <onnx-ml.pb.h>
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "exceptions.hpp"
#include "cast.hpp" #include "cast.hpp"
#include "exceptions.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <string>
#include "core/attribute.hpp" #include "core/attribute.hpp"
#include "op/abs.hpp" #include "op/abs.hpp"
...@@ -109,6 +110,11 @@ namespace ngraph ...@@ -109,6 +110,11 @@ namespace ngraph
return ops_bridge::get()(node); return ops_bridge::get()(node);
} }
static bool is_op_type_supported(const std::string& op_type)
{
return ops_bridge::get().is_op_type_supported_(op_type);
}
private: private:
std::map<std::string, std::function<NodeVector(const Node&)>> m_map; std::map<std::string, std::function<NodeVector(const Node&)>> m_map;
...@@ -204,6 +210,12 @@ namespace ngraph ...@@ -204,6 +210,12 @@ namespace ngraph
std::function<NodeVector(const Node&)> factory{it->second}; std::function<NodeVector(const Node&)> factory{it->second};
return factory(node); return factory(node);
} }
bool is_op_type_supported_(const std::string& op_type) const
{
auto it = m_map.find(op_type);
return !(it == m_map.end());
}
}; };
} // namespace detail } // namespace detail
...@@ -215,6 +227,11 @@ namespace ngraph ...@@ -215,6 +227,11 @@ namespace ngraph
return detail::ops_bridge::make_ng_nodes(node); return detail::ops_bridge::make_ng_nodes(node);
} }
bool is_op_type_supported(const std::string& op_type)
{
return detail::ops_bridge::is_op_type_supported(op_type);
}
} // namespace ops_bridge } // namespace ops_bridge
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,6 +26,7 @@ namespace ngraph ...@@ -26,6 +26,7 @@ namespace ngraph
namespace ops_bridge namespace ops_bridge
{ {
NodeVector make_ng_nodes(const onnx_import::Node&); NodeVector make_ng_nodes(const onnx_import::Node&);
bool is_op_type_supported(const std::string& op_type);
} }
} // namespace onnx_import } // namespace onnx_import
......
ngraph ONNXImporter:
$
ABmissing_op_node1"
FakeOpName
+
BCmissing_op_node2"AnotherFakeOpName

CX supported_op"Abs
test_graphZ
A

b
X

B
\ No newline at end of file
...@@ -1276,3 +1276,24 @@ TEST(onnx, model_thresholded_relu) ...@@ -1276,3 +1276,24 @@ TEST(onnx, model_thresholded_relu)
Outputs outputs{execute(function, inputs, "INTERPRETER")}; Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front())); EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
} }
TEST(onnx, model_unsupported_op)
{
try
{
onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/unsupported_op.onnx"));
FAIL() << "Expected ngraph::ngraph_error";
}
catch (ngraph::ngraph_error const& err)
{
std::string what{err.what()};
EXPECT_NE(what.find("unknown operations"), std::string::npos);
EXPECT_NE(what.find("FakeOpName"), std::string::npos);
EXPECT_NE(what.find("AnotherFakeOpName"), std::string::npos);
}
catch (...)
{
FAIL() << "Expected ngraph::ngraph_error";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment