Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
49a32b14
Commit
49a32b14
authored
Jan 16, 2019
by
Adam Rogowiec
Committed by
Michał Karzyński
Jan 16, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] Expose API to check whether ONNX Op is supported. (#2299)
parent
1db9707f
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
109 additions
and
8 deletions
+109
-8
onnx.cpp
src/ngraph/frontend/onnx_import/onnx.cpp
+8
-0
onnx.hpp
src/ngraph/frontend/onnx_import/onnx.hpp
+12
-0
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+46
-8
ops_bridge.hpp
src/ngraph/frontend/onnx_import/ops_bridge.hpp
+12
-0
onnx_import.cpp
test/onnx_import.cpp
+31
-0
No files found.
src/ngraph/frontend/onnx_import/onnx.cpp
View file @
49a32b14
...
...
@@ -99,6 +99,14 @@ namespace ngraph
return
op_list
;
}
bool
is_operator_supported
(
const
std
::
string
&
op_name
,
std
::
int64_t
version
,
const
std
::
string
&
domain
)
{
return
OperatorsBridge
::
is_operator_registered
(
op_name
,
version
,
domain
==
"ai.onnx"
?
""
:
domain
);
}
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/onnx.hpp
View file @
49a32b14
...
...
@@ -52,6 +52,18 @@ namespace ngraph
std
::
set
<
std
::
string
>
get_supported_operators
(
std
::
int64_t
version
,
const
std
::
string
&
domain
);
/// \brief Determines whether ONNX operator is supported.
///
/// \param[in] op_name The ONNX operator name.
/// \param[in] version The ONNX operator set version.
/// \param[in] domain The domain the ONNX operator is registered to.
///
/// \return True if operator is supported, False otherwise.
///
bool
is_operator_supported
(
const
std
::
string
&
op_name
,
std
::
int64_t
version
,
const
std
::
string
&
domain
=
"ai.onnx"
);
/// \brief Convert an ONNX model to nGraph function
/// The function translated serialized ONNX model to nGraph function. The serialized
/// ONNX model is read from input stream.
...
...
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
49a32b14
...
...
@@ -21,6 +21,7 @@
#include <unordered_map>
#include "core/attribute.hpp"
#include "ngraph/log.hpp"
#include "op/abs.hpp"
#include "op/acos.hpp"
#include "op/add.hpp"
...
...
@@ -102,20 +103,19 @@ namespace ngraph
{
namespace
detail
{
const
Operator
&
find
(
const
std
::
string
&
name
,
std
::
int64_t
version
,
const
std
::
string
&
domain
,
const
std
::
map
<
std
::
int64_t
,
Operator
>&
map
)
const
std
::
map
<
std
::
int64_t
,
Operator
>::
const_iterator
find
(
std
::
int64_t
version
,
const
std
::
map
<
std
::
int64_t
,
Operator
>&
map
)
{
std
::
map
<
std
::
int64_t
,
Operator
>::
const_iterator
it
{};
while
(
version
>
0
)
{
const
auto
it
=
map
.
find
(
version
--
);
it
=
map
.
find
(
version
--
);
if
(
it
!=
std
::
end
(
map
))
{
return
it
->
second
;
return
it
;
}
}
throw
error
::
UnsupportedVersion
{
name
,
version
,
domain
}
;
return
it
;
}
}
...
...
@@ -136,13 +136,51 @@ namespace ngraph
{
throw
error
::
UnknownDomain
{
domain
};
}
if
(
version
>
OperatorsBridge
::
LATEST_SUPPORTED_OPSET_VERSION
)
{
NGRAPH_WARN
<<
"Currently operator set version: "
<<
version
<<
" is unsupported."
<<
" Falling back to: "
<<
OperatorsBridge
::
LATEST_SUPPORTED_OPSET_VERSION
;
}
for
(
const
auto
&
op
:
dm
->
second
)
{
result
.
emplace
(
op
.
first
,
detail
::
find
(
op
.
first
,
version
,
domain
,
op
.
second
));
const
auto
&
it
=
detail
::
find
(
version
,
op
.
second
);
if
(
it
==
std
::
end
(
op
.
second
))
{
throw
error
::
UnsupportedVersion
{
op
.
first
,
version
,
domain
};
}
result
.
emplace
(
op
.
first
,
it
->
second
);
}
return
result
;
}
bool
OperatorsBridge
::
_is_operator_registered
(
const
std
::
string
&
name
,
std
::
int64_t
version
,
const
std
::
string
&
domain
)
{
// search for domain
auto
dm_map
=
m_map
.
find
(
domain
);
if
(
dm_map
==
std
::
end
(
m_map
))
{
return
false
;
}
// search for name
auto
op_map
=
dm_map
->
second
.
find
(
name
);
if
(
op_map
==
std
::
end
(
dm_map
->
second
))
{
return
false
;
}
if
(
detail
::
find
(
version
,
op_map
->
second
)
!=
std
::
end
(
op_map
->
second
))
{
return
true
;
}
else
{
return
false
;
}
}
#define REGISTER_OPERATOR(name_, ver_, fn_) \
m_map[""][name_].emplace(ver_, std::bind(op::set_##ver_::fn_, std::placeholders::_1))
...
...
src/ngraph/frontend/onnx_import/ops_bridge.hpp
View file @
49a32b14
...
...
@@ -62,6 +62,8 @@ namespace ngraph
class
OperatorsBridge
{
public
:
static
constexpr
const
int
LATEST_SUPPORTED_OPSET_VERSION
=
ONNX_OPSET_VERSION
;
OperatorsBridge
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
&
operator
=
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
(
OperatorsBridge
&&
)
=
delete
;
...
...
@@ -80,6 +82,13 @@ namespace ngraph
instance
().
_register_operator
(
name
,
version
,
domain
,
std
::
move
(
fn
));
}
static
bool
is_operator_registered
(
const
std
::
string
&
name
,
std
::
int64_t
version
,
const
std
::
string
&
domain
)
{
return
instance
().
_is_operator_registered
(
name
,
version
,
domain
);
}
private
:
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
std
::
map
<
std
::
int64_t
,
Operator
>>>
...
...
@@ -98,6 +107,9 @@ namespace ngraph
const
std
::
string
&
domain
,
Operator
fn
);
OperatorSet
_get_operator_set
(
std
::
int64_t
version
,
const
std
::
string
&
domain
);
bool
_is_operator_registered
(
const
std
::
string
&
name
,
std
::
int64_t
version
,
const
std
::
string
&
domain
);
};
}
// namespace onnx_import
...
...
test/onnx_import.cpp
View file @
49a32b14
...
...
@@ -1678,3 +1678,34 @@ TEST(onnx, model_argmin_int32)
execute
<
std
::
int32_t
,
std
::
int64_t
>
(
function
,
inputs
,
"INTERPRETER"
)};
EXPECT_TRUE
(
test
::
all_close
(
expected_output
.
front
(),
outputs
.
front
()));
}
TEST
(
onnx
,
model_is_op_supported
)
{
// Simple case
EXPECT_TRUE
(
onnx_import
::
is_operator_supported
(
"Sum"
,
1
,
"ai.onnx"
));
// With fallback
EXPECT_TRUE
(
onnx_import
::
is_operator_supported
(
"Sum"
,
100
,
"ai.onnx"
));
// Different opset versions
EXPECT_TRUE
(
onnx_import
::
is_operator_supported
(
"Add"
,
1
,
"ai.onnx"
));
EXPECT_TRUE
(
onnx_import
::
is_operator_supported
(
"Add"
,
7
,
"ai.onnx"
));
// Default domain name
EXPECT_TRUE
(
onnx_import
::
is_operator_supported
(
"Sum"
,
1
));
// Unregistered operator
EXPECT_FALSE
(
onnx_import
::
is_operator_supported
(
"DummyOp"
,
1
));
EXPECT_FALSE
(
onnx_import
::
is_operator_supported
(
"DummyOp"
,
1
,
"ai.onnx"
));
EXPECT_FALSE
(
onnx_import
::
is_operator_supported
(
"DummyOp"
,
10
,
"ai.onnx"
));
// Operator with bad domain name
EXPECT_FALSE
(
onnx_import
::
is_operator_supported
(
"Sum"
,
1
,
"bad.domain"
));
// Registered custom operator
onnx_import
::
register_operator
(
"AddQ"
,
1
,
"com.intel.ai"
,
[](
const
onnx_import
::
Node
&
node
)
->
NodeVector
{
NodeVector
ng_inputs
{
node
.
get_ng_inputs
()};
return
{
std
::
make_shared
<
ngraph
::
op
::
Add
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
});
EXPECT_TRUE
(
onnx_import
::
is_operator_supported
(
"AddQ"
,
1
,
"com.intel.ai"
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment