Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
1f350378
Commit
1f350378
authored
Feb 01, 2019
by
tsocha
Committed by
Michał Karzyński
Feb 01, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] Enable deselected supported opset domain when needed. (#2350)
parent
676f8d36
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
165 additions
and
36 deletions
+165
-36
attribute.cpp
src/ngraph/frontend/onnx_import/core/attribute.cpp
+2
-2
attribute.hpp
src/ngraph/frontend/onnx_import/core/attribute.hpp
+2
-2
graph.cpp
src/ngraph/frontend/onnx_import/core/graph.cpp
+46
-14
graph.hpp
src/ngraph/frontend/onnx_import/core/graph.hpp
+2
-2
model.cpp
src/ngraph/frontend/onnx_import/core/model.cpp
+23
-2
model.hpp
src/ngraph/frontend/onnx_import/core/model.hpp
+9
-0
onnx.cpp
src/ngraph/frontend/onnx_import/onnx.cpp
+2
-1
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+18
-7
ops_bridge.hpp
src/ngraph/frontend/onnx_import/ops_bridge.hpp
+20
-4
missing_op_domain.onnx
test/models/onnx/missing_op_domain.onnx
+17
-0
onnx_import.in.cpp
test/onnx_import.in.cpp
+24
-2
No files found.
src/ngraph/frontend/onnx_import/core/attribute.cpp
View file @
1f350378
...
@@ -22,7 +22,7 @@ namespace ngraph
...
@@ -22,7 +22,7 @@ namespace ngraph
{
{
namespace
onnx_import
namespace
onnx_import
{
{
std
::
vector
<
Graph
>
Attribute
::
get_graph_array
(
const
Model
&
model
)
const
std
::
vector
<
Graph
>
Attribute
::
get_graph_array
(
Model
&
model
)
const
{
{
std
::
vector
<
Graph
>
result
;
std
::
vector
<
Graph
>
result
;
for
(
const
auto
&
graph
:
m_attribute_proto
->
graphs
())
for
(
const
auto
&
graph
:
m_attribute_proto
->
graphs
())
...
@@ -32,7 +32,7 @@ namespace ngraph
...
@@ -32,7 +32,7 @@ namespace ngraph
return
result
;
return
result
;
}
}
Graph
Attribute
::
get_graph
(
const
Model
&
model
)
const
Graph
Attribute
::
get_graph
(
Model
&
model
)
const
{
{
return
Graph
{
m_attribute_proto
->
g
(),
model
};
return
Graph
{
m_attribute_proto
->
g
(),
model
};
}
}
...
...
src/ngraph/frontend/onnx_import/core/attribute.hpp
View file @
1f350378
...
@@ -278,7 +278,7 @@ namespace ngraph
...
@@ -278,7 +278,7 @@ namespace ngraph
float
get_float
()
const
{
return
m_attribute_proto
->
f
();
}
float
get_float
()
const
{
return
m_attribute_proto
->
f
();
}
int64_t
get_integer
()
const
{
return
m_attribute_proto
->
i
();
}
int64_t
get_integer
()
const
{
return
m_attribute_proto
->
i
();
}
const
std
::
string
&
get_string
()
const
{
return
m_attribute_proto
->
s
();
}
const
std
::
string
&
get_string
()
const
{
return
m_attribute_proto
->
s
();
}
Graph
get_graph
(
const
Model
&
)
const
;
Graph
get_graph
(
Model
&
)
const
;
std
::
vector
<
Tensor
>
get_tensor_array
()
const
std
::
vector
<
Tensor
>
get_tensor_array
()
const
{
{
...
@@ -303,7 +303,7 @@ namespace ngraph
...
@@ -303,7 +303,7 @@ namespace ngraph
std
::
end
(
m_attribute_proto
->
strings
())};
std
::
end
(
m_attribute_proto
->
strings
())};
}
}
std
::
vector
<
Graph
>
get_graph_array
(
const
Model
&
)
const
;
std
::
vector
<
Graph
>
get_graph_array
(
Model
&
)
const
;
/* explicit */
operator
onnx
::
AttributeProto_AttributeType
()
const
/* explicit */
operator
onnx
::
AttributeProto_AttributeType
()
const
{
{
...
...
src/ngraph/frontend/onnx_import/core/graph.cpp
View file @
1f350378
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
// limitations under the License.
// limitations under the License.
//*****************************************************************************
//*****************************************************************************
#include <functional>
#include <set>
#include <set>
#include "graph.hpp"
#include "graph.hpp"
...
@@ -25,26 +26,40 @@ namespace ngraph
...
@@ -25,26 +26,40 @@ namespace ngraph
{
{
namespace
detail
namespace
detail
{
{
std
::
string
to_string
(
const
std
::
set
<
std
::
string
>&
set
)
static
std
::
string
to_string
(
const
std
::
map
<
std
::
string
,
std
::
reference_wrapper
<
const
onnx
::
NodeProto
>>&
map
)
{
{
std
::
string
result
;
std
::
string
result
;
for
(
auto
it
=
std
::
begin
(
set
);
it
!=
std
::
end
(
set
);
++
it
)
for
(
auto
it
=
std
::
begin
(
map
);
it
!=
std
::
end
(
map
);
++
it
)
{
{
result
+=
(
it
!=
std
::
begin
(
set
)
?
", "
:
""
)
+
*
i
t
;
result
+=
(
it
!=
std
::
begin
(
map
)
?
", "
:
""
)
+
it
->
firs
t
;
}
}
return
result
;
return
result
;
}
}
inline
std
::
string
to_string
(
const
onnx
::
NodeProto
&
node_proto
)
static
std
::
string
get_node_domain
(
const
onnx
::
NodeProto
&
node_proto
)
{
{
return
(
node_proto
.
domain
().
empty
()
?
""
:
node_proto
.
domain
()
+
"."
)
+
return
(
node_proto
.
domain
().
empty
()
?
""
:
node_proto
.
domain
());
node_proto
.
op_type
();
}
}
}
Graph
::
Graph
(
const
onnx
::
GraphProto
&
graph_proto
,
/// \brief Gets the operator represented by provided node unique identificator.
const
Model
&
model
,
///
const
Weights
&
weights
)
/// \param[in] node_proto The node protobuf representation object.
///
/// \note The operator is uniquely identified by the tuple (domain, op_type,
/// since_version). The first two elements are stored in NodeProto object,
/// thus we use only them.
///
/// \return The unique identificator.
///
static
std
::
string
get_op_domain_and_name
(
const
onnx
::
NodeProto
&
node_proto
)
{
std
::
string
domain
=
get_node_domain
(
node_proto
);
return
(
domain
.
empty
()
?
""
:
domain
+
"."
)
+
node_proto
.
op_type
();
}
}
// namespace detail
Graph
::
Graph
(
const
onnx
::
GraphProto
&
graph_proto
,
Model
&
model
,
const
Weights
&
weights
)
:
m_graph_proto
{
&
graph_proto
}
:
m_graph_proto
{
&
graph_proto
}
,
m_model
{
&
model
}
,
m_model
{
&
model
}
{
{
...
@@ -70,17 +85,34 @@ namespace ngraph
...
@@ -70,17 +85,34 @@ namespace ngraph
}
}
// Verify that ONNX graph contains only nodes of available operator types
// Verify that ONNX graph contains only nodes of available operator types
std
::
set
<
std
::
string
>
unknown_operator_type
s
;
std
::
map
<
std
::
string
,
std
::
reference_wrapper
<
const
onnx
::
NodeProto
>>
unknown_operator
s
;
for
(
const
auto
&
node_proto
:
m_graph_proto
->
node
())
for
(
const
auto
&
node_proto
:
m_graph_proto
->
node
())
{
{
if
(
!
m_model
->
is_operator_available
(
node_proto
))
if
(
!
m_model
->
is_operator_available
(
node_proto
))
{
{
unknown_operator_types
.
emplace
(
detail
::
to_string
(
node_proto
));
unknown_operators
.
emplace
(
detail
::
get_op_domain_and_name
(
node_proto
),
node_proto
);
// Try adding missing domain
m_model
->
enable_opset_domain
(
detail
::
get_node_domain
(
node_proto
));
}
}
// Reverify wheter we still have any unavailable operators.
auto
it
=
std
::
begin
(
unknown_operators
);
while
(
it
!=
std
::
end
(
unknown_operators
))
{
if
(
m_model
->
is_operator_available
(
it
->
second
))
{
it
=
unknown_operators
.
erase
(
it
);
}
else
{
it
++
;
}
}
}
}
NGRAPH_ASSERT
(
unknown_operator
_types
.
empty
())
NGRAPH_ASSERT
(
unknown_operator
s
.
empty
())
<<
"unknown operations: "
<<
"unknown operations: "
<<
detail
::
to_string
(
unknown_operator_type
s
);
<<
detail
::
to_string
(
unknown_operator
s
);
// Process ONNX graph nodes, convert to nGraph nodes
// Process ONNX graph nodes, convert to nGraph nodes
for
(
const
auto
&
node_proto
:
m_graph_proto
->
node
())
for
(
const
auto
&
node_proto
:
m_graph_proto
->
node
())
...
...
src/ngraph/frontend/onnx_import/core/graph.hpp
View file @
1f350378
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
class
Graph
class
Graph
{
{
public
:
public
:
Graph
(
const
onnx
::
GraphProto
&
proto
,
const
Model
&
model
,
const
Weights
&
weights
=
{});
Graph
(
const
onnx
::
GraphProto
&
proto
,
Model
&
model
,
const
Weights
&
weights
=
{});
const
std
::
vector
<
Node
>&
get_nodes
()
const
{
return
m_nodes
;
}
const
std
::
vector
<
Node
>&
get_nodes
()
const
{
return
m_nodes
;
}
const
std
::
vector
<
ValueInfo
>&
get_inputs
()
const
{
return
m_inputs
;
}
const
std
::
vector
<
ValueInfo
>&
get_inputs
()
const
{
return
m_inputs
;
}
...
@@ -59,7 +59,7 @@ namespace ngraph
...
@@ -59,7 +59,7 @@ namespace ngraph
ParameterVector
m_parameters
;
ParameterVector
m_parameters
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>
m_ng_node_cache
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>
m_ng_node_cache
;
std
::
map
<
std
::
string
,
Tensor
>
m_initializers
;
std
::
map
<
std
::
string
,
Tensor
>
m_initializers
;
const
Model
*
m_model
;
Model
*
m_model
;
};
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
outs
,
const
Graph
&
graph
)
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
outs
,
const
Graph
&
graph
)
...
...
src/ngraph/frontend/onnx_import/core/model.cpp
View file @
1f350378
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <onnx-ml.pb.h>
#include <onnx-ml.pb.h>
#include "model.hpp"
#include "model.hpp"
#include "ngraph/log.hpp"
#include "ops_bridge.hpp"
#include "ops_bridge.hpp"
namespace
ngraph
namespace
ngraph
...
@@ -33,14 +34,14 @@ namespace ngraph
...
@@ -33,14 +34,14 @@ namespace ngraph
{
{
m_opset
.
emplace
(
id
.
domain
(),
m_opset
.
emplace
(
id
.
domain
(),
OperatorsBridge
::
get_operator_set
(
OperatorsBridge
::
get_operator_set
(
id
.
version
(),
(
id
.
domain
()
==
"ai.onnx"
?
""
:
id
.
domain
()
)));
(
id
.
domain
()
==
"ai.onnx"
?
""
:
id
.
domain
()),
id
.
version
(
)));
}
}
// onnx.proto(.3): the empty string ("") for domain or absence of opset_import field
// onnx.proto(.3): the empty string ("") for domain or absence of opset_import field
// implies the operator set that is defined as part of the ONNX specification.
// implies the operator set that is defined as part of the ONNX specification.
const
auto
dm
=
m_opset
.
find
(
""
);
const
auto
dm
=
m_opset
.
find
(
""
);
if
(
dm
==
std
::
end
(
m_opset
))
if
(
dm
==
std
::
end
(
m_opset
))
{
{
m_opset
.
emplace
(
""
,
OperatorsBridge
::
get_operator_set
(
ONNX_OPSET_VERSION
,
""
));
m_opset
.
emplace
(
""
,
OperatorsBridge
::
get_operator_set
(
""
,
ONNX_OPSET_VERSION
));
}
}
}
}
...
@@ -71,6 +72,26 @@ namespace ngraph
...
@@ -71,6 +72,26 @@ namespace ngraph
return
(
op
!=
std
::
end
(
dm
->
second
));
return
(
op
!=
std
::
end
(
dm
->
second
));
}
}
void
Model
::
enable_opset_domain
(
const
std
::
string
&
domain
)
{
// There is no need to 'update' already enabled domain.
// Since this function may be called only during model import,
// (maybe multiple times) the registered domain opset won't differ
// between subsequent calls.
if
(
m_opset
.
find
(
domain
)
==
std
::
end
(
m_opset
))
{
OperatorSet
opset
{
OperatorsBridge
::
get_operator_set
(
domain
)};
if
(
opset
.
empty
())
{
NGRAPH_WARN
<<
"Couldn't enable domain: "
<<
domain
<<
" since it hasn't any registered operators."
;
return
;
}
m_opset
.
emplace
(
domain
,
opset
);
}
}
}
// namespace onnx_import
}
// namespace onnx_import
}
// namespace ngraph
}
// namespace ngraph
src/ngraph/frontend/onnx_import/core/model.hpp
View file @
1f350378
...
@@ -61,6 +61,15 @@ namespace ngraph
...
@@ -61,6 +61,15 @@ namespace ngraph
/// \return `true` if the operator is available, otherwise it returns `false`.
/// \return `true` if the operator is available, otherwise it returns `false`.
bool
is_operator_available
(
const
onnx
::
NodeProto
&
node_proto
)
const
;
bool
is_operator_available
(
const
onnx
::
NodeProto
&
node_proto
)
const
;
/// \brief Enable operators from provided domain to use by this model.
///
/// \note This function makes visible all currently registered in provided domain
/// operators for use in this model.
///
/// \param[in] domain The domain name.
///
void
enable_opset_domain
(
const
std
::
string
&
domain
);
private
:
private
:
const
onnx
::
ModelProto
*
m_model_proto
;
const
onnx
::
ModelProto
*
m_model_proto
;
std
::
unordered_map
<
std
::
string
,
OperatorSet
>
m_opset
;
std
::
unordered_map
<
std
::
string
,
OperatorSet
>
m_opset
;
...
...
src/ngraph/frontend/onnx_import/onnx.cpp
View file @
1f350378
...
@@ -90,7 +90,8 @@ namespace ngraph
...
@@ -90,7 +90,8 @@ namespace ngraph
std
::
set
<
std
::
string
>
get_supported_operators
(
std
::
int64_t
version
,
std
::
set
<
std
::
string
>
get_supported_operators
(
std
::
int64_t
version
,
const
std
::
string
&
domain
)
const
std
::
string
&
domain
)
{
{
OperatorSet
op_set
{
OperatorsBridge
::
get_operator_set
(
version
,
domain
)};
OperatorSet
op_set
{
OperatorsBridge
::
get_operator_set
(
domain
==
"ai.onnx"
?
""
:
domain
,
version
)};
std
::
set
<
std
::
string
>
op_list
{};
std
::
set
<
std
::
string
>
op_list
{};
for
(
const
auto
&
op
:
op_set
)
for
(
const
auto
&
op
:
op_set
)
{
{
...
...
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
1f350378
...
@@ -110,6 +110,11 @@ namespace ngraph
...
@@ -110,6 +110,11 @@ namespace ngraph
find
(
std
::
int64_t
version
,
const
std
::
map
<
std
::
int64_t
,
Operator
>&
map
)
find
(
std
::
int64_t
version
,
const
std
::
map
<
std
::
int64_t
,
Operator
>&
map
)
{
{
std
::
map
<
std
::
int64_t
,
Operator
>::
const_iterator
it
{};
std
::
map
<
std
::
int64_t
,
Operator
>::
const_iterator
it
{};
// Get the latest version.
if
(
version
==
-
1
)
{
return
map
.
empty
()
?
std
::
end
(
map
)
:
--
std
::
end
(
map
);
}
while
(
version
>
0
)
while
(
version
>
0
)
{
{
it
=
map
.
find
(
version
--
);
it
=
map
.
find
(
version
--
);
...
@@ -127,23 +132,29 @@ namespace ngraph
...
@@ -127,23 +132,29 @@ namespace ngraph
const
std
::
string
&
domain
,
const
std
::
string
&
domain
,
Operator
fn
)
Operator
fn
)
{
{
m_map
[
domain
][
name
].
emplace
(
version
,
std
::
move
(
fn
));
auto
result
=
m_map
[
domain
][
name
].
emplace
(
version
,
std
::
move
(
fn
));
if
(
result
.
second
)
{
NGRAPH_WARN
<<
"Overwriting existing operator: "
<<
domain
+
"."
+
name
+
":"
+
std
::
to_string
(
version
);
}
}
}
OperatorSet
OperatorsBridge
::
_get_operator_set
(
std
::
int64_t
versio
n
,
OperatorSet
OperatorsBridge
::
_get_operator_set
(
const
std
::
string
&
domai
n
,
const
std
::
string
&
domai
n
)
std
::
int64_t
versio
n
)
{
{
OperatorSet
result
;
OperatorSet
result
;
auto
dm
=
m_map
.
find
(
domain
);
auto
dm
=
m_map
.
find
(
domain
);
if
(
dm
==
std
::
end
(
m_map
))
if
(
dm
==
std
::
end
(
m_map
))
{
{
throw
error
::
UnknownDomain
{
domain
};
throw
error
::
UnknownDomain
{
domain
};
}
}
if
(
version
>
OperatorsBridge
::
LATEST_SUPPORTED
_OPSET_VERSION
)
if
(
domain
==
""
&&
version
>
OperatorsBridge
::
LATEST_SUPPORTED_ONNX
_OPSET_VERSION
)
{
{
NGRAPH_WARN
<<
"Currently
operator set version: "
<<
version
<<
" is unsupported."
NGRAPH_WARN
<<
"Currently
ONNX operator set version: "
<<
version
<<
" Falling back to: "
<<
"
is unsupported.
Falling back to: "
<<
OperatorsBridge
::
LATEST_SUPPORTED_OPSET_VERSION
;
<<
OperatorsBridge
::
LATEST_SUPPORTED_O
NNX_O
PSET_VERSION
;
}
}
for
(
const
auto
&
op
:
dm
->
second
)
for
(
const
auto
&
op
:
dm
->
second
)
{
{
...
...
src/ngraph/frontend/onnx_import/ops_bridge.hpp
View file @
1f350378
...
@@ -62,16 +62,17 @@ namespace ngraph
...
@@ -62,16 +62,17 @@ namespace ngraph
class
OperatorsBridge
class
OperatorsBridge
{
{
public
:
public
:
static
constexpr
const
int
LATEST_SUPPORTED_OPSET_VERSION
=
ONNX_OPSET_VERSION
;
static
constexpr
const
int
LATEST_SUPPORTED_O
NNX_O
PSET_VERSION
=
ONNX_OPSET_VERSION
;
OperatorsBridge
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
&
operator
=
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
&
operator
=
(
const
OperatorsBridge
&
)
=
delete
;
OperatorsBridge
(
OperatorsBridge
&&
)
=
delete
;
OperatorsBridge
(
OperatorsBridge
&&
)
=
delete
;
OperatorsBridge
&
operator
=
(
OperatorsBridge
&&
)
=
delete
;
OperatorsBridge
&
operator
=
(
OperatorsBridge
&&
)
=
delete
;
static
OperatorSet
get_operator_set
(
std
::
int64_t
version
,
const
std
::
string
&
domain
)
static
OperatorSet
get_operator_set
(
const
std
::
string
&
domain
,
std
::
int64_t
version
=
-
1
)
{
{
return
instance
().
_get_operator_set
(
version
,
domai
n
);
return
instance
().
_get_operator_set
(
domain
,
versio
n
);
}
}
static
void
register_operator
(
const
std
::
string
&
name
,
static
void
register_operator
(
const
std
::
string
&
name
,
...
@@ -90,6 +91,20 @@ namespace ngraph
...
@@ -90,6 +91,20 @@ namespace ngraph
}
}
private
:
private
:
// Registered operators structure
// {
// domain_1: {
// op_type_1: {
// version_1: {func_handle},
// version_2: {func_handle},
// ...
// },
// op_type_2: { ... }
// ...
// },
// domain_2: { ... },
// ...
// }
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
std
::
map
<
std
::
int64_t
,
Operator
>>>
std
::
unordered_map
<
std
::
string
,
std
::
map
<
std
::
int64_t
,
Operator
>>>
m_map
;
m_map
;
...
@@ -106,7 +121,8 @@ namespace ngraph
...
@@ -106,7 +121,8 @@ namespace ngraph
std
::
int64_t
version
,
std
::
int64_t
version
,
const
std
::
string
&
domain
,
const
std
::
string
&
domain
,
Operator
fn
);
Operator
fn
);
OperatorSet
_get_operator_set
(
std
::
int64_t
version
,
const
std
::
string
&
domain
);
OperatorSet
_get_operator_set
(
const
std
::
string
&
domain
,
std
::
int64_t
version
);
bool
_is_operator_registered
(
const
std
::
string
&
name
,
bool
_is_operator_registered
(
const
std
::
string
&
name
,
std
::
int64_t
version
,
std
::
int64_t
version
,
const
std
::
string
&
domain
);
const
std
::
string
&
domain
);
...
...
test/models/onnx/missing_op_domain.onnx
0 → 100644
View file @
1f350378
ONNXnGraphImporter:o
A
BC" CustomAdd: custom.op compute_graphZ
A
Z
B
b
C
B
\ No newline at end of file
test/onnx_import.in.cpp
View file @
1f350378
...
@@ -1820,6 +1820,29 @@ TEST(onnx_${BACKEND_NAME}, model_space_to_depth_no_blocksize)
...
@@ -1820,6 +1820,29 @@ TEST(onnx_${BACKEND_NAME}, model_space_to_depth_no_blocksize)
std
::
runtime_error
);
std
::
runtime_error
);
}
}
TEST
(
onnx_
$
{
BACKEND_NAME
},
model_missing_op_domain
)
{
onnx_import
::
register_operator
(
"CustomAdd"
,
1
,
"custom.op"
,
[](
const
onnx_import
::
Node
&
node
)
->
NodeVector
{
NodeVector
ng_inputs
{
node
.
get_ng_inputs
()};
return
{
std
::
make_shared
<
ngraph
::
op
::
Add
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
});
EXPECT_TRUE
(
onnx_import
::
is_operator_supported
(
"CustomAdd"
,
1
,
"custom.op"
));
auto
function
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/missing_op_domain.onnx"
));
Inputs
inputs
;
inputs
.
emplace_back
(
std
::
vector
<
float
>
{
0.
f
,
1.
f
,
2.
f
,
3.
f
});
inputs
.
emplace_back
(
std
::
vector
<
float
>
{
0.
f
,
1.
f
,
2.
f
,
3.
f
});
Outputs
expected_output
{
std
::
vector
<
float
>
{
0.
f
,
2.
f
,
4.
f
,
6.
f
}};
Outputs
outputs
{
execute
(
function
,
inputs
,
"${BACKEND_NAME}"
)};
EXPECT_TRUE
(
test
::
all_close_f
(
expected_output
.
front
(),
outputs
.
front
()));
}
TEST
(
onnx_
$
{
BACKEND_NAME
},
model_top_k
)
TEST
(
onnx_
$
{
BACKEND_NAME
},
model_top_k
)
{
{
auto
function
=
auto
function
=
...
@@ -1839,4 +1862,4 @@ TEST(onnx_${BACKEND_NAME}, model_top_k)
...
@@ -1839,4 +1862,4 @@ TEST(onnx_${BACKEND_NAME}, model_top_k)
EXPECT_TRUE
(
test
::
all_close_f
(
expected_values_output
,
values_output
));
EXPECT_TRUE
(
test
::
all_close_f
(
expected_values_output
,
values_output
));
EXPECT_TRUE
(
test
::
all_close
(
expected_indices_output
,
indices_output
));
EXPECT_TRUE
(
test
::
all_close
(
expected_indices_output
,
indices_output
));
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment