Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
1f350378
Commit
1f350378
authored
Feb 01, 2019
by
tsocha
Committed by
Michał Karzyński
Feb 01, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] Enable deselected supported opset domain when needed. (#2350)
parent
676f8d36
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
165 additions
and
36 deletions
+165
-36
attribute.cpp
src/ngraph/frontend/onnx_import/core/attribute.cpp
+2
-2
attribute.hpp
src/ngraph/frontend/onnx_import/core/attribute.hpp
+2
-2
graph.cpp
src/ngraph/frontend/onnx_import/core/graph.cpp
+46
-14
graph.hpp
src/ngraph/frontend/onnx_import/core/graph.hpp
+2
-2
model.cpp
src/ngraph/frontend/onnx_import/core/model.cpp
+23
-2
model.hpp
src/ngraph/frontend/onnx_import/core/model.hpp
+9
-0
onnx.cpp
src/ngraph/frontend/onnx_import/onnx.cpp
+2
-1
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+18
-7
ops_bridge.hpp
src/ngraph/frontend/onnx_import/ops_bridge.hpp
+20
-4
missing_op_domain.onnx
test/models/onnx/missing_op_domain.onnx
+17
-0
onnx_import.in.cpp
test/onnx_import.in.cpp
+24
-2
No files found.
src/ngraph/frontend/onnx_import/core/attribute.cpp
View file @
1f350378
...
...
@@ -22,7 +22,7 @@ namespace ngraph
{
namespace
onnx_import
{
std
::
vector
<
Graph
>
Attribute
::
get_graph_array
(
const
Model
&
model
)
const
std
::
vector
<
Graph
>
Attribute
::
get_graph_array
(
Model
&
model
)
const
{
std
::
vector
<
Graph
>
result
;
for
(
const
auto
&
graph
:
m_attribute_proto
->
graphs
())
...
...
@@ -32,7 +32,7 @@ namespace ngraph
return
result
;
}
Graph
Attribute
::
get_graph
(
const
Model
&
model
)
const
Graph
Attribute
::
get_graph
(
Model
&
model
)
const
{
return
Graph
{
m_attribute_proto
->
g
(),
model
};
}
...
...
src/ngraph/frontend/onnx_import/core/attribute.hpp
View file @
1f350378
...
...
@@ -278,7 +278,7 @@ namespace ngraph
float
get_float
()
const
{
return
m_attribute_proto
->
f
();
}
int64_t
get_integer
()
const
{
return
m_attribute_proto
->
i
();
}
const
std
::
string
&
get_string
()
const
{
return
m_attribute_proto
->
s
();
}
Graph
get_graph
(
const
Model
&
)
const
;
Graph
get_graph
(
Model
&
)
const
;
std
::
vector
<
Tensor
>
get_tensor_array
()
const
{
...
...
@@ -303,7 +303,7 @@ namespace ngraph
std
::
end
(
m_attribute_proto
->
strings
())};
}
std
::
vector
<
Graph
>
get_graph_array
(
const
Model
&
)
const
;
std
::
vector
<
Graph
>
get_graph_array
(
Model
&
)
const
;
/* explicit */
operator
onnx
::
AttributeProto_AttributeType
()
const
{
...
...
src/ngraph/frontend/onnx_import/core/graph.cpp
View file @
1f350378
...
...
@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <set>
#include "graph.hpp"
...
...
@@ -25,26 +26,40 @@ namespace ngraph
{
namespace
detail
{
std
::
string
to_string
(
const
std
::
set
<
std
::
string
>&
set
)
static
std
::
string
to_string
(
const
std
::
map
<
std
::
string
,
std
::
reference_wrapper
<
const
onnx
::
NodeProto
>>&
map
)
{
std
::
string
result
;
for
(
auto
it
=
std
::
begin
(
set
);
it
!=
std
::
end
(
set
);
++
it
)
for
(
auto
it
=
std
::
begin
(
map
);
it
!=
std
::
end
(
map
);
++
it
)
{
result
+=
(
it
!=
std
::
begin
(
set
)
?
", "
:
""
)
+
*
i
t
;
result
+=
(
it
!=
std
::
begin
(
map
)
?
", "
:
""
)
+
it
->
firs
t
;
}
return
result
;
}
inline
std
::
string
to_string
(
const
onnx
::
NodeProto
&
node_proto
)
static
std
::
string
get_node_domain
(
const
onnx
::
NodeProto
&
node_proto
)
{
return
(
node_proto
.
domain
().
empty
()
?
""
:
node_proto
.
domain
()
+
"."
)
+
node_proto
.
op_type
();
return
(
node_proto
.
domain
().
empty
()
?
""
:
node_proto
.
domain
());
}
}
Graph
::
Graph
(
const
onnx
::
GraphProto
&
graph_proto
,
const
Model
&
model
,
const
Weights
&
weights
)
/// \brief Gets the operator represented by provided node unique identificator.
///
/// \param[in] node_proto The node protobuf representation object.
///
/// \note The operator is uniquely identified by the tuple (domain, op_type,
/// since_version). The first two elements are stored in NodeProto object,
/// thus we use only them.
///
/// \return The unique identificator.
///
static
std
::
string
get_op_domain_and_name
(
const
onnx
::
NodeProto
&
node_proto
)
{
std
::
string
domain
=
get_node_domain
(
node_proto
);
return
(
domain
.
empty
()
?
""
:
domain
+
"."
)
+
node_proto
.
op_type
();
}
}
// namespace detail
Graph
::
Graph
(
const
onnx
::
GraphProto
&
graph_proto
,
Model
&
model
,
const
Weights
&
weights
)
:
m_graph_proto
{
&
graph_proto
}
,
m_model
{
&
model
}
{
...
...
@@ -70,17 +85,34 @@ namespace ngraph
}
// Verify that ONNX graph contains only nodes of available operator types
std
::
set
<
std
::
string
>
unknown_operator_type
s
;
std
::
map
<
std
::
string
,
std
::
reference_wrapper
<
const
onnx
::
NodeProto
>>
unknown_operator
s
;
for
(
const
auto
&
node_proto
:
m_graph_proto
->
node
())
{
if
(
!
m_model
->
is_operator_available
(
node_proto
))
{
unknown_operator_types
.
emplace
(
detail
::
to_string
(
node_proto
));
unknown_operators
.
emplace
(
detail
::
get_op_domain_and_name
(
node_proto
),
node_proto
);
// Try adding missing domain
m_model
->
enable_opset_domain
(
detail
::
get_node_domain
(
node_proto
));
}
}
// Reverify wheter we still have any unavailable operators.
auto
it
=
std
::
begin
(
unknown_operators
);
while
(
it
!=
std
::
end
(
unknown_operators
))
{
if
(
m_model
->
is_operator_available
(
it
->
second
))
{
it
=
unknown_operators
.
erase
(
it
);
}
else
{
it
++
;
}
}
NGRAPH_ASSERT
(
unknown_operator
_types
.
empty
())
<<
"unknown operations: "
<<
detail
::
to_string
(
unknown_operator_type
s
);
NGRAPH_ASSERT
(
unknown_operator
s
.
empty
())
<<
"unknown operations: "
<<
detail
::
to_string
(
unknown_operator
s
);
// Process ONNX graph nodes, convert to nGraph nodes
for
(
const
auto
&
node_proto
:
m_graph_proto
->
node
())
...
...
src/ngraph/frontend/onnx_import/core/graph.hpp
View file @
1f350378
...
...
@@ -33,7 +33,7 @@ namespace ngraph
class
Graph
{
public
:
Graph
(
const
onnx
::
GraphProto
&
proto
,
const
Model
&
model
,
const
Weights
&
weights
=
{});
Graph
(
const
onnx
::
GraphProto
&
proto
,
Model
&
model
,
const
Weights
&
weights
=
{});
const
std
::
vector
<
Node
>&
get_nodes
()
const
{
return
m_nodes
;
}
const
std
::
vector
<
ValueInfo
>&
get_inputs
()
const
{
return
m_inputs
;
}
...
...
@@ -59,7 +59,7 @@ namespace ngraph
ParameterVector
m_parameters
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>
m_ng_node_cache
;
std
::
map
<
std
::
string
,
Tensor
>
m_initializers
;
const
Model
*
m_model
;
Model
*
m_model
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
outs
,
const
Graph
&
graph
)
...
...
src/ngraph/frontend/onnx_import/core/model.cpp
View file @
1f350378
...
...
@@ -17,6 +17,7 @@
#include <onnx-ml.pb.h>
#include "model.hpp"
#include "ngraph/log.hpp"
#include "ops_bridge.hpp"
namespace
ngraph
...
...
@@ -33,14 +34,14 @@ namespace ngraph
{
m_opset
.
emplace
(
id
.
domain
(),
OperatorsBridge
::
get_operator_set
(
id
.
version
(),
(
id
.
domain
()
==
"ai.onnx"
?
""
:
id
.
domain
()
)));
(
id
.
domain
()
==
"ai.onnx"
?
""
:
id
.
domain
()),
id
.
version
(
)));
}
// onnx.proto(.3): the empty string ("") for domain or absence of opset_import field
// implies the operator set that is defined as part of the ONNX specification.
const
auto
dm
=
m_opset
.
find
(
""
);
if
(
dm
==
std
::
end
(
m_opset
))
{
m_opset
.
emplace
(
""
,
OperatorsBridge
::
get_operator_set
(
ONNX_OPSET_VERSION
,
""
));
m_opset
.
emplace
(
""
,
OperatorsBridge
::
get_operator_set
(
""
,
ONNX_OPSET_VERSION
));
}
}
...
...
@@ -71,6 +72,26 @@ namespace ngraph
return
(
op
!=
std
::
end
(
dm
->
second
));
}
void
Model
::
enable_opset_domain
(
const
std
::
string
&
domain
)
{
// There is no need to 'update' already enabled domain.
// Since this function may be called only during model import,
// (maybe multiple times) the registered domain opset won't differ
// between subsequent calls.
if
(
m_opset
.
find
(
domain
)
==
std
::
end
(
m_opset
))
{
OperatorSet
opset
{
OperatorsBridge
::
get_operator_set
(
domain
)};
if
(
opset
.
empty
())
{
NGRAPH_WARN
<<
"Couldn't enable domain: "
<<
domain
<<
" since it hasn't any registered operators."
;
return
;
}
m_opset
.
emplace
(
domain
,
opset
);
}
}
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/core/model.hpp
View file @
1f350378
...
...
@@ -61,6 +61,15 @@ namespace ngraph
/// \return `true` if the operator is available, otherwise it returns `false`.
bool
is_operator_available
(
const
onnx
::
NodeProto
&
node_proto
)
const
;
/// \brief Enable operators from provided domain to use by this model.
///
/// \note This function makes visible all currently registered in provided domain
/// operators for use in this model.
///
/// \param[in] domain The domain name.
///
void
enable_opset_domain
(
const
std
::
string
&
domain
);
private
:
const
onnx
::
ModelProto
*
m_model_proto
;
std
::
unordered_map
<
std
::
string
,
OperatorSet
>
m_opset
;
...
...
src/ngraph/frontend/onnx_import/onnx.cpp
View file @
1f350378
...
...
@@ -90,7 +90,8 @@ namespace ngraph
std
::
set
<
std
::
string
>
get_supported_operators
(
std
::
int64_t
version
,
const
std
::
string
&
domain
)
{
OperatorSet
op_set
{
OperatorsBridge
::
get_operator_set
(
version
,
domain
)};
OperatorSet
op_set
{
OperatorsBridge
::
get_operator_set
(
domain
==
"ai.onnx"
?
""
:
domain
,
version
)};
std
::
set
<
std
::
string
>
op_list
{};
for
(
const
auto
&
op
:
op_set
)
{
...
...
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
1f350378
...
...
@@ -110,6 +110,11 @@ namespace ngraph
find
(
std
::
int64_t
version
,
const
std
::
map
<
std
::
int64_t
,
Operator
>&
map
)
{
std
::
map
<
std
::
int64_t
,
Operator
>::
const_iterator
it
{};
// Get the latest version.
if
(
version
==
-
1
)
{
return
map
.
empty
()
?
std
::
end
(
map
)
:
--
std
::
end
(
map
);
}
while
(
version
>
0
)
{
it
=
map
.
find
(
version
--
);
...
...
@@ -127,23 +132,29 @@ namespace ngraph
const
std
::
string
&
domain
,
Operator
fn
)
{
m_map
[
domain
][
name
].
emplace
(
version
,
std
::
move
(
fn
));
auto
result
=
m_map
[
domain
][
name
].
emplace
(
version
,
std
::
move
(
fn
));
if
(
result
.
second
)
{
NGRAPH_WARN
<<
"Overwriting existing operator: "
<<
domain
+
"."
+
name
+
":"
+
std
::
to_string
(
version
);
}
}
OperatorSet
OperatorsBridge
::
_get_operator_set
(
std
::
int64_t
versio
n
,
const
std
::
string
&
domai
n
)
OperatorSet
OperatorsBridge
::
_get_operator_set
(
const
std
::
string
&
domai
n
,
std
::
int64_t
versio
n
)
{
OperatorSet
result
;
auto
dm
=
m_map
.
find
(
domain
);
if
(
dm
==
std
::
end
(
m_map
))
{
throw
error
::
UnknownDomain
{
domain
};
}
if
(
version
>
OperatorsBridge
::
LATEST_SUPPORTED
_OPSET_VERSION
)
if
(
domain
==
""
&&
version
>
OperatorsBridge
::
LATEST_SUPPORTED_ONNX
_OPSET_VERSION
)
{
NGRAPH_WARN
<<
"Currently
operator set version: "
<<
version
<<
" is unsupported."
<<
" Falling back to: "
<<
OperatorsBridge
::
LATEST_SUPPORTED_OPSET_VERSION
;
NGRAPH_WARN
<<
"Currently
ONNX operator set version: "
<<
version
<<
"
is unsupported.
Falling back to: "
<<
OperatorsBridge
::
LATEST_SUPPORTED_O
NNX_O
PSET_VERSION
;
}
for
(
const
auto
&
op
:
dm
->
second
)
{
...
...
src/ngraph/frontend/onnx_import/ops_bridge.hpp
View file @
1f350378
...
...
@@ -62,16 +62,17 @@ namespace ngraph
class
OperatorsBridge
{
public
:
static
constexpr
const
int
LATEST_SUPPORTED_OPSET_VERSION
=
ONNX_OPSET_VERSION
;
static
constexpr
const
int
LATEST_SUPPORTED_O
NNX_O
PSET_VERSION
=
ONNX_OPSET_VERSION
;
OperatorsBridge
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
&
operator
=
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
(
OperatorsBridge
&&
)
=
delete
;
OperatorsBridge
&
operator
=
(
OperatorsBridge
&&
)
=
delete
;
static
OperatorSet
get_operator_set
(
std
::
int64_t
version
,
const
std
::
string
&
domain
)
static
OperatorSet
get_operator_set
(
const
std
::
string
&
domain
,
std
::
int64_t
version
=
-
1
)
{
return
instance
().
_get_operator_set
(
version
,
domai
n
);
return
instance
().
_get_operator_set
(
domain
,
versio
n
);
}
static
void
register_operator
(
const
std
::
string
&
name
,
...
...
@@ -90,6 +91,20 @@ namespace ngraph
}
private
:
// Registered operators structure
// {
// domain_1: {
// op_type_1: {
// version_1: {func_handle},
// version_2: {func_handle},
// ...
// },
// op_type_2: { ... }
// ...
// },
// domain_2: { ... },
// ...
// }
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
std
::
map
<
std
::
int64_t
,
Operator
>>>
m_map
;
...
...
@@ -106,7 +121,8 @@ namespace ngraph
std
::
int64_t
version
,
const
std
::
string
&
domain
,
Operator
fn
);
OperatorSet
_get_operator_set
(
std
::
int64_t
version
,
const
std
::
string
&
domain
);
OperatorSet
_get_operator_set
(
const
std
::
string
&
domain
,
std
::
int64_t
version
);
bool
_is_operator_registered
(
const
std
::
string
&
name
,
std
::
int64_t
version
,
const
std
::
string
&
domain
);
...
...
test/models/onnx/missing_op_domain.onnx
0 → 100644
View file @
1f350378
ONNXnGraphImporter:o
A
BC" CustomAdd: custom.op compute_graphZ
A
Z
B
b
C
B
\ No newline at end of file
test/onnx_import.in.cpp
View file @
1f350378
...
...
@@ -1820,6 +1820,29 @@ TEST(onnx_${BACKEND_NAME}, model_space_to_depth_no_blocksize)
std
::
runtime_error
);
}
TEST
(
onnx_
$
{
BACKEND_NAME
},
model_missing_op_domain
)
{
onnx_import
::
register_operator
(
"CustomAdd"
,
1
,
"custom.op"
,
[](
const
onnx_import
::
Node
&
node
)
->
NodeVector
{
NodeVector
ng_inputs
{
node
.
get_ng_inputs
()};
return
{
std
::
make_shared
<
ngraph
::
op
::
Add
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
});
EXPECT_TRUE
(
onnx_import
::
is_operator_supported
(
"CustomAdd"
,
1
,
"custom.op"
));
auto
function
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/missing_op_domain.onnx"
));
Inputs
inputs
;
inputs
.
emplace_back
(
std
::
vector
<
float
>
{
0.
f
,
1.
f
,
2.
f
,
3.
f
});
inputs
.
emplace_back
(
std
::
vector
<
float
>
{
0.
f
,
1.
f
,
2.
f
,
3.
f
});
Outputs
expected_output
{
std
::
vector
<
float
>
{
0.
f
,
2.
f
,
4.
f
,
6.
f
}};
Outputs
outputs
{
execute
(
function
,
inputs
,
"${BACKEND_NAME}"
)};
EXPECT_TRUE
(
test
::
all_close_f
(
expected_output
.
front
(),
outputs
.
front
()));
}
TEST
(
onnx_
$
{
BACKEND_NAME
},
model_top_k
)
{
auto
function
=
...
...
@@ -1839,4 +1862,4 @@ TEST(onnx_${BACKEND_NAME}, model_top_k)
EXPECT_TRUE
(
test
::
all_close_f
(
expected_values_output
,
values_output
));
EXPECT_TRUE
(
test
::
all_close
(
expected_indices_output
,
indices_output
));
}
\ No newline at end of file
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment