Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
c85ff3b8
Unverified
Commit
c85ff3b8
authored
Oct 18, 2018
by
Artur Wojcik
Committed by
GitHub
Oct 18, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
onnx: flatten operatos bridge hierarchy (#1846)
Signed-off-by:
Artur Wojcik
<
artur.wojcik@intel.com
>
parent
3167b167
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
89 deletions
+70
-89
onnx.cpp
src/ngraph/frontend/onnx_import/onnx.cpp
+1
-1
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+11
-85
ops_bridge.hpp
src/ngraph/frontend/onnx_import/ops_bridge.hpp
+58
-3
No files found.
src/ngraph/frontend/onnx_import/onnx.cpp
View file @
c85ff3b8
...
...
@@ -62,7 +62,7 @@ namespace ngraph
std
::
vector
<
std
::
shared_ptr
<
Function
>>
output_functions
;
Model
model
{
model_proto
};
Graph
graph
{
model_proto
.
graph
(),
ops_b
ridge
::
get_operator_set
(
model
.
get_opset_version
())};
OperatorsB
ridge
::
get_operator_set
(
model
.
get_opset_version
())};
for
(
const
auto
&
output
:
graph
.
get_outputs
())
{
output_functions
.
emplace_back
(
std
::
make_shared
<
Function
>
(
...
...
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
c85ff3b8
...
...
@@ -87,69 +87,7 @@ namespace ngraph
{
namespace
onnx_import
{
namespace
detail
{
namespace
error
{
struct
UnknownOperator
:
ngraph_error
{
explicit
UnknownOperator
(
const
std
::
string
&
op_type
)
:
ngraph_error
{
"unknown operator:
\"
"
+
op_type
+
"
\"
"
}
{
}
};
struct
UnsupportedVersion
:
ngraph_error
{
explicit
UnsupportedVersion
(
std
::
int64_t
version
)
:
ngraph_error
{
"unsupported operator set version: "
+
std
::
to_string
(
version
)}
{
}
};
}
// namespace error
class
OperatorsBridge
{
public
:
OperatorsBridge
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
&
operator
=
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
(
OperatorsBridge
&&
)
=
delete
;
OperatorsBridge
&
operator
=
(
OperatorsBridge
&&
)
=
delete
;
static
const
OperatorSet
&
get_operator_set
(
std
::
int64_t
version
)
{
return
instance
().
get_operator_set_version
(
version
);
}
private
:
std
::
unordered_map
<
std
::
string
,
std
::
map
<
std
::
int64_t
,
std
::
function
<
NodeVector
(
const
Node
&
)
>>>
m_map
;
static
const
OperatorsBridge
&
instance
()
{
static
OperatorsBridge
instance
;
return
instance
;
}
const
Operator
&
get_operator
(
const
std
::
string
&
name
,
std
::
int64_t
version
)
const
{
auto
op
=
m_map
.
find
(
name
);
if
(
op
==
std
::
end
(
m_map
))
{
throw
error
::
UnknownOperator
{
name
};
}
auto
it
=
op
->
second
.
find
(
version
);
if
(
it
==
std
::
end
(
op
->
second
))
{
throw
error
::
UnsupportedVersion
{
version
};
}
return
it
->
second
;
}
const
OperatorSet
&
get_operator_set_version_1
()
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version_1
()
const
{
static
OperatorSet
operator_set
;
if
(
operator_set
.
empty
())
...
...
@@ -168,7 +106,7 @@ namespace ngraph
return
operator_set
;
}
const
OperatorSet
&
get_operator_set_version_2
()
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version_2
()
const
{
static
OperatorSet
operator_set
;
if
(
operator_set
.
empty
())
...
...
@@ -178,7 +116,7 @@ namespace ngraph
return
operator_set
;
}
const
OperatorSet
&
get_operator_set_version_3
()
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version_3
()
const
{
static
OperatorSet
operator_set
;
if
(
operator_set
.
empty
())
...
...
@@ -188,7 +126,7 @@ namespace ngraph
return
operator_set
;
}
const
OperatorSet
&
get_operator_set_version_4
()
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version_4
()
const
{
static
OperatorSet
operator_set
;
if
(
operator_set
.
empty
())
...
...
@@ -198,7 +136,7 @@ namespace ngraph
return
operator_set
;
}
const
OperatorSet
&
get_operator_set_version_5
()
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version_5
()
const
{
static
OperatorSet
operator_set
;
if
(
operator_set
.
empty
())
...
...
@@ -208,7 +146,7 @@ namespace ngraph
return
operator_set
;
}
const
OperatorSet
&
get_operator_set_version_6
()
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version_6
()
const
{
static
OperatorSet
operator_set
;
if
(
operator_set
.
empty
())
...
...
@@ -218,7 +156,7 @@ namespace ngraph
return
operator_set
;
}
const
OperatorSet
&
get_operator_set_version_7
()
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version_7
()
const
{
static
OperatorSet
operator_set
;
if
(
operator_set
.
empty
())
...
...
@@ -228,7 +166,7 @@ namespace ngraph
return
operator_set
;
}
const
OperatorSet
&
get_operator_set_version_8
()
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version_8
()
const
{
static
OperatorSet
operator_set
;
if
(
operator_set
.
empty
())
...
...
@@ -238,7 +176,7 @@ namespace ngraph
return
operator_set
;
}
const
OperatorSet
&
get_operator_set_version_9
()
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version_9
()
const
{
static
OperatorSet
operator_set
;
if
(
operator_set
.
empty
())
...
...
@@ -258,7 +196,7 @@ namespace ngraph
#define DEFAULT_OPERATOR_SET() return OPERATOR_SET_NAME_HELPER(ONNX_OPSET_VERSION)
const
OperatorSet
&
get_operator_set_version
(
std
::
int64_t
version
)
const
const
OperatorSet
&
OperatorsBridge
::
get_operator_set_version
(
std
::
int64_t
version
)
const
{
switch
(
version
)
{
...
...
@@ -278,7 +216,7 @@ namespace ngraph
#define REGISTER_OPERATOR(name_, version_, fn_) \
m_map[name_].emplace(version_, std::bind(op::set_##version_::fn_, std::placeholders::_1))
OperatorsBridge
()
OperatorsBridge
::
OperatorsBridge
()
{
REGISTER_OPERATOR
(
"Abs"
,
1
,
abs
);
REGISTER_OPERATOR
(
"Add"
,
1
,
add
);
...
...
@@ -351,18 +289,6 @@ namespace ngraph
REGISTER_OPERATOR
(
"Unsqueeze"
,
1
,
unsqueeze
);
REGISTER_OPERATOR
(
"Xor"
,
1
,
logical_xor
);
}
};
}
// namespace detail
namespace
ops_bridge
{
const
OperatorSet
&
get_operator_set
(
std
::
int64_t
version
)
{
return
detail
::
OperatorsBridge
::
get_operator_set
(
version
);
}
}
// namespace ops_bridge
}
// namespace onnx_import
...
...
src/ngraph/frontend/onnx_import/ops_bridge.hpp
View file @
c85ff3b8
...
...
@@ -17,6 +17,11 @@
#pragma once
#include <cstdint>
#include <map>
#include <string>
#include <unordered_map>
#include "ngraph/except.hpp"
#include "core/operator_set.hpp"
...
...
@@ -24,11 +29,61 @@ namespace ngraph
{
namespace
onnx_import
{
namespace
ops_bridge
namespace
error
{
struct
UnknownOperator
:
ngraph_error
{
explicit
UnknownOperator
(
const
std
::
string
&
op_type
)
:
ngraph_error
{
"unknown operator:
\"
"
+
op_type
+
"
\"
"
}
{
}
};
struct
UnsupportedVersion
:
ngraph_error
{
explicit
UnsupportedVersion
(
std
::
int64_t
version
)
:
ngraph_error
{
"unsupported operator set version: "
+
std
::
to_string
(
version
)}
{
}
};
}
// namespace error
class
OperatorsBridge
{
public
:
OperatorsBridge
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
&
operator
=
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
(
OperatorsBridge
&&
)
=
delete
;
OperatorsBridge
&
operator
=
(
OperatorsBridge
&&
)
=
delete
;
static
const
OperatorSet
&
get_operator_set
(
std
::
int64_t
version
)
{
return
instance
().
get_operator_set_version
(
version
);
}
private
:
std
::
unordered_map
<
std
::
string
,
std
::
map
<
std
::
int64_t
,
Operator
>>
m_map
;
OperatorsBridge
();
static
const
OperatorsBridge
&
instance
()
{
const
OperatorSet
&
get_operator_set
(
std
::
int64_t
version
);
static
OperatorsBridge
instance
;
return
instance
;
}
}
// namespace ops_bridge
const
OperatorSet
&
get_operator_set_version_1
()
const
;
const
OperatorSet
&
get_operator_set_version_2
()
const
;
const
OperatorSet
&
get_operator_set_version_3
()
const
;
const
OperatorSet
&
get_operator_set_version_4
()
const
;
const
OperatorSet
&
get_operator_set_version_5
()
const
;
const
OperatorSet
&
get_operator_set_version_6
()
const
;
const
OperatorSet
&
get_operator_set_version_7
()
const
;
const
OperatorSet
&
get_operator_set_version_8
()
const
;
const
OperatorSet
&
get_operator_set_version_9
()
const
;
const
OperatorSet
&
get_operator_set_version
(
std
::
int64_t
version
)
const
;
};
}
// namespace onnx_import
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment