ngraph / Commits / Commit 4ecdb791

Merge branch 'master' into jmenon/codegen

Authored Oct 06, 2017 by Jai Menon; committed by GitHub on Oct 06, 2017.
Parents: d185b48c, 65aeb4b5

Showing 7 changed files with 692 additions and 90 deletions (+692, -90).
Changed files:

  src/ngraph/descriptor/layout/tensor_view_layout.hpp   +5    -0
  src/ngraph/ops/convert.cpp                             +3    -2
  src/ngraph/ops/convert.hpp                             +2    -2
  src/ngraph/runtime/eigen/convert.hpp                   +51   -0
  src/ngraph/runtime/external_function.cpp               +378  -83
  test/execute.cpp                                       +210  -0
  test/type_prop.cpp                                     +43   -3
src/ngraph/descriptor/layout/tensor_view_layout.hpp
...
@@ -51,6 +51,11 @@ namespace ngraph
            /// With non-linear buffers, this will need to be something other than size_t.
            virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0;

+           const element::Type& get_element_type() const
+           {
+               return m_tensor_view.get_tensor_view_type()->get_element_type();
+           }
            const Shape& get_shape() const
            {
                return m_tensor_view.get_tensor_view_type()->get_shape();
...
src/ngraph/ops/convert.cpp
View file @
4ecdb791
...
@@ -17,9 +17,10 @@
 #include "ngraph/ngraph.hpp"

 using namespace std;
 using namespace ngraph;
 using namespace ngraph::op;

-void Convert::propagate_types()
+const element::Type& Convert::propagate_element_types(const element::Type& arg_element_type) const
 {
-    throw ngraph_error("NIY");
+    return m_element_type;
 }
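With this change, Convert stops throwing "NIY" and participates in element-type propagation: the result element type is whatever was passed to the constructor, independent of the argument's element type. A minimal sketch of the resulting behavior, mirroring the convert_deduce test added to test/type_prop.cpp later in this commit:

    // Sketch: type deduction for op::Convert (names as used in this commit's tests).
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
    auto c = make_shared<op::Convert>(param, element::Int32::element_type());
    c->propagate_types(); // value type becomes TensorViewType(Int32, Shape{2, 3, 4})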
src/ngraph/ops/convert.hpp
View file @
4ecdb791
...
@@ -27,9 +27,9 @@ namespace ngraph
            {
            }

+           virtual const element::Type&
+               propagate_element_types(const element::Type& arg_element_type) const override;
            virtual std::string description() const override { return "Convert"; }
-           virtual void propagate_types() override;

        protected:
            const ngraph::element::Type& m_element_type;
        };
...
src/ngraph/runtime/eigen/convert.hpp (new file, mode 100644)
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------

#pragma once

#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"

namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            template <typename ETI, typename ETO>
            class ConvertInstruction : public Instruction
            {
            public:
                ConvertInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
                    : m_arg(arg)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    EigenArray1d<ETO>(call_frame, m_out) =
                        EigenArray1d<ETI>(call_frame, m_arg).template cast<typename ETO::type>();
                }

            protected:
                TensorViewInfo m_arg;
                TensorViewInfo m_out;
            };
        }
    }
}
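The conversion itself is one Eigen expression: both tensor views are flattened to 1-d arrays, and cast<ETO::type>() converts element by element with ordinary C++ conversion semantics (the .template disambiguator is needed only because EigenArray1d<ETI> depends on a template parameter). The same idiom in plain Eigen, outside ngraph:

    #include <Eigen/Core>
    #include <cstdint>

    void convert_demo()
    {
        // Flat 1-d views, converted element-wise; int32 -> float here.
        Eigen::Array<int32_t, Eigen::Dynamic, 1> src(4);
        src << 1, 2, 3, 4;
        Eigen::Array<float, Eigen::Dynamic, 1> dst = src.cast<float>();
        // dst now holds {1.0f, 2.0f, 3.0f, 4.0f}.
    }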
src/ngraph/runtime/external_function.cpp
View file @
4ecdb791
...
...
@@ -28,6 +28,7 @@
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/equal.hpp"
...
...
@@ -59,6 +60,7 @@
#include "ngraph/runtime/eigen/concat_matrix.hpp"
#include "ngraph/runtime/eigen/concat_vector.hpp"
#include "ngraph/runtime/eigen/constant.hpp"
#include "ngraph/runtime/eigen/convert.hpp"
#include "ngraph/runtime/eigen/copy.hpp"
#include "ngraph/runtime/eigen/divide.hpp"
#include "ngraph/runtime/eigen/dot.hpp"
...
...
@@ -102,19 +104,214 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
const std::vector<TensorViewInfo>& in, \
const std::vector<TensorViewInfo>& out)
// Suppress Clang's complaints about the ,##__VA_ARGS__ token-pasting hack, which is a GNU extension
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#define DO_ON_ELEMENT_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Bool::element_type()) \
{ \
macro(element::Bool, ##__VA_ARGS__); \
} \
else if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else if (et == element::UInt8::element_type()) \
{ \
macro(element::UInt8, ##__VA_ARGS__); \
} \
else if (et == element::UInt32::element_type()) \
{ \
macro(element::UInt32, ##__VA_ARGS__); \
} \
else if (et == element::UInt64::element_type()) \
{ \
macro(element::UInt64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
#define DO_ON_NUMERIC_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else if (et == element::UInt8::element_type()) \
{ \
macro(element::UInt8, ##__VA_ARGS__); \
} \
else if (et == element::UInt32::element_type()) \
{ \
macro(element::UInt32, ##__VA_ARGS__); \
} \
else if (et == element::UInt64::element_type()) \
{ \
macro(element::UInt64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
#define DO_ON_SIGNED_NUMERIC_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
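Each DO_ON_* macro is a runtime-to-compile-time dispatcher: it compares the dynamic element::Type tag against each supported type and invokes the supplied macro with the matching static type as its first argument. Roughly, with a hypothetical one-argument helper M:

    // #define M(T) do_something<T>();
    // DO_ON_SIGNED_NUMERIC_TYPE(et, "unhandled element type", M) expands to:
    {
        if (et == element::Float32::element_type()) { do_something<element::Float32>(); }
        else if (et == element::Int8::element_type()) { do_something<element::Int8>(); }
        else if (et == element::Int32::element_type()) { do_something<element::Int32>(); }
        else if (et == element::Int64::element_type()) { do_something<element::Int64>(); }
        else { throw ngraph_error("unhandled element type"); }
    }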
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
REGISTER_TO_OP_MAP(op_class) \
{ \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \
}
// Versions that include the descriptor
#define REGISTER_UNOP(op_class, instr_class) \
REGISTER_INSTRUCTION(op_class, instr_class, in[0], out[0])
#define REGISTER_BINOP(op_class, instr_class) \
REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], out[0])
#define REGISTER_TERNOP(op_class, instr_class) \
REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], in[2], out[0])
#define M_REGISTER_SIGNED_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_SIGNED_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_SIGNED_NUMERIC_TYPE( \
et, \
"Internal error: signed numeric unop has unhandled element type", \
M_REGISTER_SIGNED_NUMERIC_UNOP, \
instr_class); \
}
#define M_REGISTER_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_NUMERIC_TYPE(et, \
"Internal error: numeric unop has unhandled element type", \
M_REGISTER_NUMERIC_UNOP, \
instr_class); \
}
#define M_REGISTER_NUMERIC_BINOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], in[1], out[0]));
#define REGISTER_NUMERIC_BINOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_NUMERIC_TYPE(et, \
"Internal error: numeric binop has unhandled element type", \
M_REGISTER_NUMERIC_BINOP, \
instr_class); \
}
#define M_REGISTER_POLYMORPHIC_BINOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], in[1], out[0]));
#define REGISTER_POLYMORPHIC_BINOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_ELEMENT_TYPE(et, \
"Internal error: polymorphic binop has unhandled element type", \
M_REGISTER_POLYMORPHIC_BINOP, \
instr_class); \
}
// Something sneaky here: note the at(1) instead of at(0). For the one ternop registered
// this way (op::Select), argument 0 is the boolean selector, so the element type must be
// taken from argument 1.
#define M_REGISTER_POLYMORPHIC_TERNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], in[1], in[2], out[0]));
#define REGISTER_POLYMORPHIC_TERNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(1)->get_value_type())) \
->get_element_type(); \
DO_ON_ELEMENT_TYPE(et, \
"Internal error: polymorphic ternop has unhandled element type", \
M_REGISTER_POLYMORPHIC_TERNOP, \
instr_class); \
}
#define REGISTER_CONSTANT_INSTRUCTIONS(T) \
{ \
REGISTER_INSTRUCTION( \
op::ScalarConstant<T>, \
runtime::eigen::ConstantInstruction<T>, \
std::vector<T::type>{dynamic_cast<const op::ScalarConstant<T>*>(n)->get_value()}, \
out[0]); \
REGISTER_INSTRUCTION( \
op::TensorConstant<T>, \
runtime::eigen::ConstantInstruction<T>, \
std::vector<T::type>{ \
dynamic_cast<const op::TensorConstant<T>*>(n)->get_value()->get_vector()}, \
out[0]); \
}
#define PUSH_INSTRUCTION(T, instr, ...) \
{ \
ef->get_instructions()->push_back(make_shared<instr<T>>(__VA_ARGS__)); \
}
#define PUSH_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_ELEMENT_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
#define PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_NUMERIC_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
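PUSH_POLYMORPHIC_INSTRUCTION composes the two layers above: DO_ON_ELEMENT_TYPE picks the concrete element type, and PUSH_INSTRUCTION instantiates the instruction template with it and appends the result to the function's instruction list. On the Float32 branch, a call like PUSH_POLYMORPHIC_INSTRUCTION(et, "msg", runtime::eigen::CopyInstruction, src, dst) — src/dst standing in for the forwarded __VA_ARGS__ — bottoms out as:

    ef->get_instructions()->push_back(
        make_shared<runtime::eigen::CopyInstruction<element::Float32>>(src, dst));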
// Turn off complaint suppression (see above)
#pragma clang diagnostic pop
// Define code generators for handled ops.
ExternalFunction::OpMap& ExternalFunction::get_op_map()
...
...
@@ -123,34 +320,34 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
    static OpMap op_map;
    if (!initialized)
    {
-       REGISTER_UNOP(op::Abs, runtime::eigen::AbsInstruction<element::Float32>);
-       REGISTER_BINOP(op::Add, runtime::eigen::AddInstruction<element::Float32>);
-       REGISTER_BINOP(op::Divide, runtime::eigen::DivideInstruction<element::Float32>);
-       REGISTER_BINOP(op::Equal, runtime::eigen::EqualInstruction<element::Float32>);
-       REGISTER_BINOP(op::Greater, runtime::eigen::GreaterThanInstruction<element::Float32>);
-       REGISTER_BINOP(op::GreaterEq, runtime::eigen::GreaterEqInstruction<element::Float32>);
-       REGISTER_BINOP(op::Less, runtime::eigen::LessThanInstruction<element::Float32>);
-       REGISTER_BINOP(op::LessEq, runtime::eigen::LessEqInstruction<element::Float32>);
-       REGISTER_UNOP(op::Log, runtime::eigen::LogInstruction<element::Float32>);
-       REGISTER_BINOP(op::Maximum, runtime::eigen::MaximumInstruction<element::Float32>);
-       REGISTER_BINOP(op::Multiply, runtime::eigen::MultiplyInstruction<element::Float32>);
-       REGISTER_UNOP(op::Negative, runtime::eigen::NegateInstruction<element::Float32>);
-       REGISTER_BINOP(op::NotEqual, runtime::eigen::NotEqualInstruction<element::Float32>);
-       REGISTER_TERNOP(op::Select, runtime::eigen::SelectInstruction<element::Float32>);
-       REGISTER_BINOP(op::Subtract, runtime::eigen::SubtractInstruction<element::Float32>);
-       REGISTER_INSTRUCTION(
-           op::ScalarConstant<element::Float32>,
-           runtime::eigen::ConstantInstruction<element::Float32>,
-           std::vector<element::Float32::type>{
-               dynamic_cast<const op::ScalarConstant<element::Float32>*>(n)->get_value()},
-           out[0]);
-       REGISTER_INSTRUCTION(
-           op::TensorConstant<element::Float32>,
-           runtime::eigen::ConstantInstruction<element::Float32>,
-           dynamic_cast<const op::TensorConstant<element::Float32>*>(n)->get_value()->get_vector(),
-           out[0]);
+       REGISTER_NUMERIC_UNOP(op::Log, runtime::eigen::LogInstruction);
+       REGISTER_NUMERIC_UNOP(op::Negative, runtime::eigen::NegateInstruction);
+       REGISTER_SIGNED_NUMERIC_UNOP(op::Abs, runtime::eigen::AbsInstruction);
+       REGISTER_NUMERIC_BINOP(op::Add, runtime::eigen::AddInstruction);
+       REGISTER_NUMERIC_BINOP(op::Divide, runtime::eigen::DivideInstruction);
+       REGISTER_NUMERIC_BINOP(op::Greater, runtime::eigen::GreaterThanInstruction);
+       REGISTER_NUMERIC_BINOP(op::GreaterEq, runtime::eigen::GreaterEqInstruction);
+       REGISTER_NUMERIC_BINOP(op::Less, runtime::eigen::LessThanInstruction);
+       REGISTER_NUMERIC_BINOP(op::LessEq, runtime::eigen::LessEqInstruction);
+       REGISTER_NUMERIC_BINOP(op::Maximum, runtime::eigen::MaximumInstruction);
+       REGISTER_NUMERIC_BINOP(op::Multiply, runtime::eigen::MultiplyInstruction);
+       REGISTER_NUMERIC_BINOP(op::Subtract, runtime::eigen::SubtractInstruction);
+       REGISTER_POLYMORPHIC_BINOP(op::Equal, runtime::eigen::EqualInstruction);
+       REGISTER_POLYMORPHIC_BINOP(op::NotEqual, runtime::eigen::NotEqualInstruction);
+       REGISTER_POLYMORPHIC_TERNOP(op::Select, runtime::eigen::SelectInstruction);
+       REGISTER_CONSTANT_INSTRUCTIONS(element::Bool);
+       REGISTER_CONSTANT_INSTRUCTIONS(element::Float32);
+       REGISTER_CONSTANT_INSTRUCTIONS(element::Int8);
+       REGISTER_CONSTANT_INSTRUCTIONS(element::Int32);
+       REGISTER_CONSTANT_INSTRUCTIONS(element::Int64);
+       REGISTER_CONSTANT_INSTRUCTIONS(element::UInt8);
+       REGISTER_CONSTANT_INSTRUCTIONS(element::UInt32);
+       REGISTER_CONSTANT_INSTRUCTIONS(element::UInt64);
        REGISTER_TO_OP_MAP(op::Broadcast)
        {
...
...
@@ -166,40 +363,46 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
            auto arg_shape = arg_tensor_type->get_shape();
            auto result_shape = result_tensor_type->get_shape();
+           auto& result_element_type = result_tensor_type->get_element_type();

            if (broadcast->get_broadcast_axes().empty())
            {
                // Degenerate case: no broadcast axes is just a copy.
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
-                       in[0].get_index(), out[0].get_index()));
+               PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
+                                            "Broadcast has unhandled element type",
+                                            runtime::eigen::CopyInstruction,
+                                            in[0].get_index(),
+                                            out[0].get_index());
            }
            else if (arg_shape.size() == 0)
            {
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::BroadcastScalarInstruction<element::Float32>>(
-                       in[0], out[0]));
+               PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
+                                            "Broadcast has unhandled element type",
+                                            runtime::eigen::BroadcastScalarInstruction,
+                                            in[0],
+                                            out[0]);
            }
            else if (arg_shape.size() == 1 && result_shape.size() == 2)
            {
                if (broadcast->get_broadcast_axes() == AxisSet{1})
                {
-                   ef->get_instructions()->push_back(
-                       make_shared<runtime::eigen::BroadcastVectorColwiseInstruction<element::Float32>>(
-                           in[0], out[0]));
+                   PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
+                                                "Broadcast has unhandled element type",
+                                                runtime::eigen::BroadcastVectorColwiseInstruction,
+                                                in[0],
+                                                out[0]);
                }
                else if (broadcast->get_broadcast_axes() == AxisSet{0})
                {
-                   ef->get_instructions()->push_back(
-                       make_shared<runtime::eigen::BroadcastVectorRowwiseInstruction<element::Float32>>(
-                           in[0], out[0]));
+                   PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
+                                                "Broadcast has unhandled element type",
+                                                runtime::eigen::BroadcastVectorRowwiseInstruction,
+                                                in[0],
+                                                out[0]);
                }
                else
                {
                    throw ngraph_error(
-                       "Internal error: axis set for vector-matrix broadcast is neither {0} or "
+                       "Internal error: axis set for vector-matrix broadcast is neither {0} nor "
                        "{1}");
                }
            }
...
...
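For orientation: the broadcast axes are the result axes along which the input is replicated. In test_broadcast_vector_rowwise_int64 later in this commit, shape {4} is broadcast to {3, 4} with AxisSet{0}, turning {1, 2, 3, 4} into {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4} in row-major order; AxisSet{1} is the column-wise case, replicating a length-3 input along axis 1.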
@@ -216,20 +419,25 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
            assert(nullptr != result_tensor_type);

            auto result_shape = result_tensor_type->get_shape();
+           auto& result_element_type = result_tensor_type->get_element_type();

            if (result_shape.size() == 1)
            {
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>(
-                       in, out[0]));
+               PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
+                                            "Concat has unhandled element type",
+                                            runtime::eigen::ConcatVectorInstruction,
+                                            in,
+                                            out[0]);
            }
            else if (result_shape.size() == 2)
            {
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::ConcatMatrixInstruction<element::Float32>>(
-                       in,
-                       (dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
-                       out[0]));
+               PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
+                                            "Concat has unhandled element type",
+                                            runtime::eigen::ConcatMatrixInstruction,
+                                            in,
+                                            (dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
+                                            out[0]);
            }
            else
            {
...
...
@@ -237,6 +445,62 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
}
};
        REGISTER_TO_OP_MAP(op::Convert)
        {
            auto arg = n->get_arguments().at(0);

            auto arg_tensor_type =
                dynamic_pointer_cast<const TensorViewType>(arg->get_value_type());
            assert(nullptr != arg_tensor_type);
            auto& arg_element_type = arg_tensor_type->get_element_type();

            auto result_tensor_type =
                dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
            assert(nullptr != result_tensor_type);
            auto& result_element_type = result_tensor_type->get_element_type();

// Hacky macro: we are going to be building up a series of else-ifs for each possible
// pair of element types.
#define REGISTER_CONVERT(TI, TO)                                                     \
    else if (arg_element_type == (TI::element_type()) &&                             \
             result_element_type == (TO::element_type()))                            \
    {                                                                                \
        ef->get_instructions()->push_back(                                           \
            make_shared<runtime::eigen::ConvertInstruction<TI, TO>>(in[0], out[0])); \
    }
// End hacky macro

// Hacky macro: Given some type TI, generate the else-ifs for TI to every other element
// type.
#define REGISTER_CONVERTS(TI)                                                        \
    REGISTER_CONVERT(TI, element::Bool)                                              \
    REGISTER_CONVERT(TI, element::Float32)                                           \
    REGISTER_CONVERT(TI, element::Int8)                                              \
    REGISTER_CONVERT(TI, element::Int32)                                             \
    REGISTER_CONVERT(TI, element::Int64)                                             \
    REGISTER_CONVERT(TI, element::UInt8)                                             \
    REGISTER_CONVERT(TI, element::UInt32)                                            \
    REGISTER_CONVERT(TI, element::UInt64)
// End hacky macro

            if (false)
            {
            }
            REGISTER_CONVERTS(element::Bool)
            REGISTER_CONVERTS(element::Float32)
            REGISTER_CONVERTS(element::Int8)
            REGISTER_CONVERTS(element::Int32)
            REGISTER_CONVERTS(element::Int64)
            REGISTER_CONVERTS(element::UInt8)
            REGISTER_CONVERTS(element::UInt32)
            REGISTER_CONVERTS(element::UInt64)
            else
            {
                throw ngraph_error("Internal error: cannot convert between element types");
            }
#undef REGISTER_CONVERTS
#undef REGISTER_CONVERT
        };
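The lone if (false) {} exists so the macro-generated chain, every clause of which begins with else if, has something to attach to. Abbreviated to two of the 64 type pairs, the preprocessor output looks like:

    if (false)
    {
    }
    else if (arg_element_type == element::Bool::element_type() &&
             result_element_type == element::Bool::element_type())
    {
        ef->get_instructions()->push_back(
            make_shared<runtime::eigen::ConvertInstruction<element::Bool, element::Bool>>(in[0], out[0]));
    }
    else if (arg_element_type == element::Bool::element_type() &&
             result_element_type == element::Float32::element_type())
    {
        // ... and so on through every source/destination pair ...
    }
    else
    {
        throw ngraph_error("Internal error: cannot convert between element types");
    }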
        REGISTER_TO_OP_MAP(op::Dot)
        {
            auto& arg_nodes = n->get_arguments();
...
...
@@ -253,44 +517,59 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
            auto arg0_shape = arg0_tensor_type->get_shape();
            auto arg1_shape = arg1_tensor_type->get_shape();
+           auto& arg0_element_type = arg0_tensor_type->get_element_type();

            // If arg0 or arg1 is a scalar, emit a scalar-tensor product.
            if (arg0_shape.size() == 0)
            {
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>(
-                       in[0], in[1], out[0]));
+               PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
+                                                    "Dot has unhandled element type",
+                                                    runtime::eigen::ScalarTensorProductInstruction,
+                                                    in[0],
+                                                    in[1],
+                                                    out[0]);
            }
            else if (arg1_shape.size() == 0)
            {
                // If arg1 is the scalar, do the same thing but switch the order of operands.
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>(
-                       in[1], in[0], out[0]));
+               PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
+                                                    "Dot has unhandled element type",
+                                                    runtime::eigen::ScalarTensorProductInstruction,
+                                                    in[1],
+                                                    in[0],
+                                                    out[0]);
            }
            // If arg0 and arg1 are both vectors, emit a dot product.
            else if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
            {
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::DotInstruction<element::Float32>>(
-                       in[0], in[1], out[0]));
+               PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
+                                                    "Dot has unhandled element type",
+                                                    runtime::eigen::DotInstruction,
+                                                    in[0],
+                                                    in[1],
+                                                    out[0]);
            }
            // If arg0 is a matrix and arg1 is a vector, emit a matrix-vector product.
            else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
            {
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::MatrixVectorProductInstruction<element::Float32>>(
-                       in[0], in[1], out[0]));
+               PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
+                                                    "Dot has unhandled element type",
+                                                    runtime::eigen::MatrixVectorProductInstruction,
+                                                    in[0],
+                                                    in[1],
+                                                    out[0]);
            }
            // If arg0 and arg1 are both matrices, emit a matrix product.
            else if (arg0_shape.size() == 2 && arg1_shape.size() == 2)
            {
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::MatrixMultInstruction<element::Float32>>(
-                       in[0], in[1], out[0]));
+               PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
+                                                    "Dot has unhandled element type",
+                                                    runtime::eigen::MatrixMultInstruction,
+                                                    in[0],
+                                                    in[1],
+                                                    out[0]);
            }
            else
...
...
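For reference, the shape-based dispatch in the Dot handler above reduces to:

    ()     x any     ->  scalar-tensor product (either operand may be the scalar)
    (n)    x (n)     ->  vector dot product
    (m, n) x (n)     ->  matrix-vector product
    (m, k) x (k, n)  ->  matrix-matrix product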
@@ -307,9 +586,17 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
        {
            auto get_tuple_element = static_cast<const op::GetTupleElement*>(n);
-           ef->get_instructions()->push_back(
-               make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
-                   in.at(get_tuple_element->get_n()).get_index(), out.at(0).get_index()));
+           auto result_tensor_type =
+               dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
+           assert(nullptr != result_tensor_type);
+           auto& result_element_type = result_tensor_type->get_element_type();
+           PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
+                                        "GetTupleElement has unhandled element type",
+                                        runtime::eigen::CopyInstruction,
+                                        in.at(get_tuple_element->get_n()).get_index(),
+                                        out.at(0).get_index());
        };
// Tuple will eventually be spliced out, with the users of out connected to the corresponding in's source; for now, we need to copy.
...
...
@@ -317,9 +604,12 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
        {
            for (size_t i = 0; i < in.size(); ++i)
            {
-               ef->get_instructions()->push_back(
-                   make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
-                       in.at(i).get_index(), out.at(i).get_index()));
+               auto& et = in.at(i).get_tensor_view_layout()->get_element_type();
+               PUSH_POLYMORPHIC_INSTRUCTION(et,
+                                            "Tuple has unhandled element type",
+                                            runtime::eigen::CopyInstruction,
+                                            in.at(i).get_index(),
+                                            out.at(i).get_index());
            }
        };
...
...
@@ -467,8 +757,13 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame(Functio
    std::vector<std::shared_ptr<ngraph::runtime::TensorView>> temps;
    for (auto tv : m_temp_views)
    {
-       temps.push_back(ngraph::runtime::make_tensor<ngraph::element::Float32>(
-           tv->get_tensor_view_type()->get_shape()));
+       auto& et = tv->get_tensor_view_type()->get_element_type();
+       auto shape = tv->get_tensor_view_type()->get_shape();
+#define M(T) temps.push_back(ngraph::runtime::make_tensor<T>(shape));
+       DO_ON_ELEMENT_TYPE(
+           et, "Internal error: tried to create temporary for unhandled element type", M);
+#undef M
    }
    return make_shared<ngraph::runtime::CallFrame>(
        m_n_inputs, m_n_outputs, temps, 0, m_instructions);
...
...
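The define/use/#undef sequence in make_call_frame is the usual way to hand a one-off code fragment to these dispatch macros: M exists only long enough for DO_ON_ELEMENT_TYPE to stamp it into the if/else chain once per element type, letting a compile-time type be chosen from the runtime tag. The bare pattern, with DISPATCH_ON_TYPE as a hypothetical stand-in for DO_ON_ELEMENT_TYPE:

    #define M(T) handle<T>(args);  // helper scoped to this one call site
        DISPATCH_ON_TYPE(tag, M);  // stamps out M(element::Bool), M(element::Float32), ...; only the matching branch runs
    #undef M                       // keep M from leaking into later code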
test/execute.cpp
...
...
@@ -14,6 +14,7 @@
#include "gtest/gtest.h"
#include <cmath>
#include "ngraph/ngraph.hpp"
using namespace std;
...
...
@@ -50,6 +51,37 @@ TEST(execute, test_abc)
    ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
}
TEST(execute, test_abc_int64)
{
    auto shape = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape);
    auto B = make_shared<op::Parameter>(element::Int64::element_type(), shape);
    auto C = make_shared<op::Parameter>(element::Int64::element_type(), shape);
    auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape);
    auto f = make_shared<Function>((A + B) * C, rt, op::Parameters{A, B, C});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Create some tensors for input/output
    auto a = ngraph::runtime::make_tensor<element::Int64>(shape);
    *a = vector<element::Int64::type>{1, 2, 3, 4};
    auto b = ngraph::runtime::make_tensor<element::Int64>(shape);
    *b = vector<element::Int64::type>{5, 6, 7, 8};
    auto c = ngraph::runtime::make_tensor<element::Int64>(shape);
    *c = vector<element::Int64::type>{9, 10, 11, 12};
    auto result = ngraph::runtime::make_tensor<element::Int64>(shape);

    (*cf)({a, b, c}, {result});
    ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());

    (*cf)({b, a, c}, {result});
    ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());

    (*cf)({a, c, b}, {result});
    ASSERT_EQ((vector<element::Int64::type>{50, 72, 98, 128}), result->get_vector());
}
// Same as test_abc, but using tuples for input and output
TEST(execute, test_abc_tuple)
{
...
...
@@ -92,6 +124,48 @@ TEST(execute, test_abc_tuple)
    ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
}
// Same as test_abc, but using tuples for input and output
TEST(execute, test_abc_tuple_int64)
{
    auto shape = Shape{2, 2};

    auto tensor_view_type = make_shared<TensorViewType>(element::Int64::element_type(), shape);
    auto ABC = make_shared<op::Parameter>(
        make_shared<TupleType>(ValueTypes{tensor_view_type, tensor_view_type, tensor_view_type}));

    auto A = make_shared<op::GetTupleElement>(ABC, 0);
    auto B = make_shared<op::GetTupleElement>(ABC, 1);
    auto C = make_shared<op::GetTupleElement>(ABC, 2);
    auto f = make_shared<Function>(
        make_shared<op::Tuple>(Nodes{(A + B) * C}), tensor_view_type, op::Parameters{ABC});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Create some tensors for input/output
    auto a = ngraph::runtime::make_tensor<element::Int64>(shape);
    *a = vector<element::Int64::type>{1, 2, 3, 4};
    auto b = ngraph::runtime::make_tensor<element::Int64>(shape);
    *b = vector<element::Int64::type>{5, 6, 7, 8};
    auto c = ngraph::runtime::make_tensor<element::Int64>(shape);
    *c = vector<element::Int64::type>{9, 10, 11, 12};
    auto abc = ngraph::runtime::make_tuple({a, b, c});
    auto bac = ngraph::runtime::make_tuple({b, a, c});
    auto acb = ngraph::runtime::make_tuple({a, c, b});
    auto result = ngraph::runtime::make_tensor<element::Int64>(shape);
    auto result_tuple = ngraph::runtime::make_tuple({result});

    (*cf)({abc}, {result_tuple});
    ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());

    (*cf)({bac}, {result_tuple});
    ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());

    (*cf)({acb}, {result_tuple});
    ASSERT_EQ((vector<element::Int64::type>{50, 72, 98, 128}), result->get_vector());
}
// Multiple retrieved values
TEST(execute, test_tuple_result)
{
...
...
@@ -206,6 +280,36 @@ TEST(execute, test_concat_matrix_rowwise)
              result->get_vector());
}
TEST(execute, test_concat_matrix_int64)
{
    auto shape_a = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
    auto shape_b = Shape{3, 2};
    auto B = make_shared<op::Parameter>(element::Int64::element_type(), shape_b);
    auto shape_c = Shape{3, 2};
    auto C = make_shared<op::Parameter>(element::Int64::element_type(), shape_c);
    auto shape_r = Shape{8, 2};
    auto rt = make_shared<TensorViewType>(element::Int64::element_type(), Shape{8, 2});
    auto f = make_shared<Function>(
        make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Create some tensors for input/output
    auto a = ngraph::runtime::make_tensor<element::Int64>(shape_a);
    *a = vector<element::Int64::type>{2, 4, 8, 16};
    auto b = ngraph::runtime::make_tensor<element::Int64>(shape_b);
    *b = vector<element::Int64::type>{1, 2, 4, 8, 16, 32};
    auto c = ngraph::runtime::make_tensor<element::Int64>(shape_c);
    *c = vector<element::Int64::type>{2, 3, 5, 7, 11, 13};
    auto result = ngraph::runtime::make_tensor<element::Int64>(shape_r);

    (*cf)({a, b, c}, {result});
    ASSERT_EQ((vector<element::Int64::type>{2, 4, 8, 16, 1, 2, 4, 8, 16, 32, 2, 3, 5, 7, 11, 13}),
              result->get_vector());
}
TEST(execute, test_concat_vector)
{
    auto shape_a = Shape{4};
...
...
@@ -560,6 +664,30 @@ TEST(execute, test_dot_matrix_vector)
    ASSERT_EQ((vector<float>{190, 486, 782, 1078}), result->get_vector());
}
TEST(execute, test_dot_matrix_vector_int64)
{
    auto shape_a = Shape{4, 4};
    auto shape_b = Shape{4};
    auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
    auto B = make_shared<op::Parameter>(element::Int64::element_type(), shape_b);
    auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape_b);
    auto f = make_shared<Function>(make_shared<op::Dot>(A, B), rt, op::Parameters{A, B});
    auto shape_r = Shape{4};

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Create some tensors for input/output
    auto a = ngraph::runtime::make_tensor<element::Int64>(shape_a);
    *a = vector<element::Int64::type>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    auto b = ngraph::runtime::make_tensor<element::Int64>(shape_b);
    *b = vector<element::Int64::type>{17, 18, 19, 20};
    auto result = ngraph::runtime::make_tensor<element::Int64>(shape_r);

    (*cf)({a, b}, {result});
    ASSERT_EQ((vector<element::Int64::type>{190, 486, 782, 1078}), result->get_vector());
}
TEST(execute, test_greater)
{
    auto shape = Shape{2, 2, 2};
...
...
@@ -1001,3 +1129,85 @@ TEST(execute, test_broadcast_vector_rowwise)
    (*cf)({a}, {result});
    ASSERT_EQ((vector<float>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}), result->get_vector());
}
TEST(execute, test_broadcast_vector_rowwise_int64)
{
    auto shape_a = Shape{4};
    auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
    auto shape_r = Shape{3, 4};
    auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape_r);
    auto f = make_shared<Function>(
        make_shared<op::Broadcast>(A, shape_r, AxisSet{0}), rt, op::Parameters{A});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Create some tensors for input/output
    auto a = ngraph::runtime::make_tensor<element::Int64>(shape_a);
    *a = vector<element::Int64::type>{1, 2, 3, 4};
    auto result = ngraph::runtime::make_tensor<element::Int64>(shape_r);

    (*cf)({a}, {result});
    ASSERT_EQ((vector<element::Int64::type>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}),
              result->get_vector());
}
TEST(execute, test_convert_int32_float32)
{
    auto shape = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape);
    auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
    auto f = make_shared<Function>(
        make_shared<op::Convert>(A, element::Float32::element_type()), rt, op::Parameters{A});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Create some tensors for input/output
    auto a = ngraph::runtime::make_tensor<element::Int32>(shape);
    *a = vector<element::Int32::type>{1, 2, 3, 4};
    auto result = ngraph::runtime::make_tensor<element::Float32>(shape);

    (*cf)({a}, {result});
    ASSERT_EQ((vector<element::Float32::type>{1, 2, 3, 4}), result->get_vector());
}
TEST(execute, test_convert_int32_bool)
{
    auto shape = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape);
    auto rt = make_shared<TensorViewType>(element::Bool::element_type(), shape);
    auto f = make_shared<Function>(
        make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Create some tensors for input/output
    auto a = ngraph::runtime::make_tensor<element::Int32>(shape);
    *a = vector<element::Int32::type>{1, 2, 3, 4};
    auto result = ngraph::runtime::make_tensor<element::Bool>(shape);

    (*cf)({a}, {result});
    ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector());
}
TEST(execute, test_convert_float32_bool)
{
    auto shape = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto rt = make_shared<TensorViewType>(element::Bool::element_type(), shape);
    auto f = make_shared<Function>(
        make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Create some tensors for input/output
    auto a = ngraph::runtime::make_tensor<element::Float32>(shape);
    *a = vector<element::Float32::type>{1, 2, 3, 4};
    auto result = ngraph::runtime::make_tensor<element::Bool>(shape);

    (*cf)({a}, {result});
    ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector());
}
test/type_prop.cpp
...
...
@@ -237,9 +237,49 @@ TEST(type_prop, concat_deduce_elem_type_mismatch)
}
}
//
// Tests for dot product.
//
TEST(type_prop, convert_deduce)
{
    // Deduce type
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
    auto c = make_shared<op::Convert>(param, element::Int32::element_type());
    c->propagate_types();
    auto c_vt = c->get_value_type();
    ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_correct)
{
    // Check deduced type against correctly specified type
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
    auto c = make_shared<op::Convert>(param, element::Int32::element_type());
    c->set_value_type(
        make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 3, 4}));
    c->propagate_types();
    auto c_vt = c->get_value_type();
    ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_incorrect)
{
    // Check deduced type against incorrectly specified type
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
    auto c = make_shared<op::Convert>(param, element::Int32::element_type());
    c->set_value_type(
        make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 14, 4}));
    try
    {
        c->propagate_types();
        // Should have thrown, so fail if it didn't
        FAIL() << "Deduced type should disagree with specified type";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, dot_deduce_scalar_2d)
{
// Deduce type for scalar/matrix arguments
...
...