Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
6a0ac42e
Commit
6a0ac42e
authored
Sep 15, 2017
by
Scott Cyphers
Committed by
GitHub
Sep 15, 2017
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #120 from NervanaSystems/cyphers/abc
Function execution
parents
9cee71f2
98f663a6
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
327 additions
and
11 deletions
+327
-11
CMakeLists.txt
src/CMakeLists.txt
+2
-1
function.cpp
src/ngraph/function.cpp
+4
-2
function.hpp
src/ngraph/function.hpp
+2
-1
ngraph.hpp
src/ngraph/ngraph.hpp
+1
-0
call_frame.hpp
src/ngraph/runtime/call_frame.hpp
+3
-4
add.hpp
src/ngraph/runtime/eigen/add.hpp
+1
-0
external_function.cpp
src/ngraph/runtime/eigen/external_function.cpp
+178
-0
external_function.hpp
src/ngraph/runtime/eigen/external_function.hpp
+59
-0
tensor_view.cpp
src/ngraph/runtime/eigen/tensor_view.cpp
+9
-0
tensor_view.hpp
src/ngraph/runtime/eigen/tensor_view.hpp
+3
-0
tensor_view.hpp
src/ngraph/runtime/tensor_view.hpp
+3
-1
type.hpp
src/ngraph/type.hpp
+8
-1
CMakeLists.txt
test/CMakeLists.txt
+1
-0
execute.cpp
test/execute.cpp
+52
-0
type_prop.cpp
test/type_prop.cpp
+1
-1
No files found.
src/CMakeLists.txt
View file @
6a0ac42e
...
@@ -19,6 +19,7 @@ set (SRC
...
@@ -19,6 +19,7 @@ set (SRC
ngraph/descriptor/output.cpp
ngraph/descriptor/output.cpp
ngraph/descriptor/tensor_view.cpp
ngraph/descriptor/tensor_view.cpp
ngraph/descriptor/tensor.cpp
ngraph/descriptor/tensor.cpp
ngraph/function.cpp
ngraph/node.cpp
ngraph/node.cpp
ngraph/shape.cpp
ngraph/shape.cpp
ngraph/pass/assign_tensors.cpp
ngraph/pass/assign_tensors.cpp
...
@@ -31,6 +32,7 @@ set (SRC
...
@@ -31,6 +32,7 @@ set (SRC
ngraph/pass/topological_sort.cpp
ngraph/pass/topological_sort.cpp
ngraph/pass/tree_pass.cpp
ngraph/pass/tree_pass.cpp
ngraph/runtime/call_frame.cpp
ngraph/runtime/call_frame.cpp
ngraph/runtime/eigen/external_function.cpp
ngraph/runtime/eigen/tensor_view.cpp
ngraph/runtime/eigen/tensor_view.cpp
ngraph/shape.cpp
ngraph/shape.cpp
ngraph/visualize.cpp
ngraph/visualize.cpp
...
@@ -42,7 +44,6 @@ set (SRC
...
@@ -42,7 +44,6 @@ set (SRC
ops/constant.cpp
ops/constant.cpp
ops/convert.cpp
ops/convert.cpp
ops/dot.cpp
ops/dot.cpp
ops/function.cpp
ops/op.cpp
ops/op.cpp
ops/parameter.cpp
ops/parameter.cpp
ops/tuple.cpp
ops/tuple.cpp
...
...
src/
ops
/function.cpp
→
src/
ngraph
/function.cpp
View file @
6a0ac42e
...
@@ -12,12 +12,14 @@
...
@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
#include <memory>
#include "ngraph/function.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
)
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
)
:
m_result
(
result
)
:
m_result
(
result
)
,
m_parameters
(
parameters
)
,
m_parameters
(
parameters
)
...
...
src/ngraph/function.hpp
View file @
6a0ac42e
...
@@ -14,9 +14,11 @@
...
@@ -14,9 +14,11 @@
#pragma once
#pragma once
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/op.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/type.hpp"
#include "ngraph/type.hpp"
namespace
ngraph
namespace
ngraph
...
@@ -34,7 +36,6 @@ namespace ngraph
...
@@ -34,7 +36,6 @@ namespace ngraph
return
m_parameters
;
return
m_parameters
;
}
}
std
::
string
get_name
()
const
{
return
m_name
;
}
std
::
string
get_name
()
const
{
return
m_name
;
}
protected
:
protected
:
std
::
shared_ptr
<
Node
>
m_result
;
std
::
shared_ptr
<
Node
>
m_result
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>>
m_parameters
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>>
m_parameters
;
...
...
src/ngraph/ngraph.hpp
View file @
6a0ac42e
...
@@ -56,6 +56,7 @@
...
@@ -56,6 +56,7 @@
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/external_function.hpp"
#include "ngraph/runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/return.hpp"
#include "ngraph/runtime/eigen/return.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
...
...
src/ngraph/runtime/call_frame.hpp
View file @
6a0ac42e
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/function.hpp"
#include "ngraph/function.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/instruction.hpp"
...
@@ -24,8 +25,6 @@ namespace ngraph
...
@@ -24,8 +25,6 @@ namespace ngraph
{
{
namespace
runtime
namespace
runtime
{
{
using
PTVs
=
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
PrimaryTensorView
>>
;
class
PrimaryTensorView
;
class
PrimaryTensorView
;
// A VM for executing lightly-compiled graph functions.
// A VM for executing lightly-compiled graph functions.
...
@@ -39,8 +38,8 @@ namespace ngraph
...
@@ -39,8 +38,8 @@ namespace ngraph
size_t
initial_pc
,
size_t
initial_pc
,
const
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
Instruction
>>>&
instructions
);
const
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
Instruction
>>>&
instructions
);
void
operator
()(
const
PTVs
&
inputs
,
const
PTVs
&
outpus
);
void
operator
()(
const
PTVs
&
inputs
,
const
PTVs
&
outpus
);
void
set_return
()
{
m_return
=
true
;
}
void
set_return
()
{
m_return
=
true
;
}
std
::
shared_ptr
<
PrimaryTensorView
>
get_tensor
(
size_t
i
)
{
return
m_tensors
[
i
];
}
std
::
shared_ptr
<
PrimaryTensorView
>
get_tensor
(
size_t
i
)
{
return
m_tensors
[
i
];
}
protected
:
protected
:
...
...
src/ngraph/runtime/eigen/add.hpp
View file @
6a0ac42e
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/instruction.hpp"
namespace
ngraph
namespace
ngraph
...
...
src/ngraph/runtime/eigen/external_function.cpp
0 → 100644
View file @
6a0ac42e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <tuple>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/external_function.hpp"
#include "ngraph/runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/return.hpp"
using
namespace
std
;
using
namespace
ngraph
::
runtime
::
eigen
;
ExternalFunction
::
ExternalFunction
()
:
m_instructions
(
make_shared
<
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Instruction
>>>
())
{
}
// Define code generators for handled ops.
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
void
(
ngraph
::
Node
*
,
ExternalFunction
*
,
const
std
::
vector
<
size_t
>&
inputs
,
const
std
::
vector
<
size_t
>&
outputs
)
>>&
ExternalFunction
::
get_op_map
()
{
static
bool
initialized
=
false
;
static
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
void
(
Node
*
,
ExternalFunction
*
,
const
std
::
vector
<
size_t
>&
inputs
,
const
std
::
vector
<
size_t
>&
outputs
)
>>
op_map
;
if
(
!
initialized
)
{
op_map
[
type_index
(
typeid
(
op
::
Add
))]
=
[](
Node
*
n
,
ExternalFunction
*
ef
,
const
std
::
vector
<
size_t
>&
in
,
const
std
::
vector
<
size_t
>&
out
)
{
ef
->
get_instructions
()
->
push_back
(
make_shared
<
runtime
::
eigen
::
AddInstruction
<
element
::
Float32
>>
(
in
[
0
],
in
[
1
],
out
[
0
]));
};
op_map
[
type_index
(
typeid
(
op
::
Multiply
))]
=
[](
Node
*
n
,
ExternalFunction
*
ef
,
const
std
::
vector
<
size_t
>&
in
,
const
std
::
vector
<
size_t
>&
out
)
{
ef
->
get_instructions
()
->
push_back
(
make_shared
<
runtime
::
eigen
::
MultiplyInstruction
<
element
::
Float32
>>
(
in
[
0
],
in
[
1
],
out
[
0
]));
};
op_map
[
type_index
(
typeid
(
op
::
Parameter
))]
=
[](
Node
*
n
,
ExternalFunction
*
ef
,
const
std
::
vector
<
size_t
>&
in
,
const
std
::
vector
<
size_t
>&
out
)
{};
initialized
=
true
;
}
return
op_map
;
}
void
ExternalFunction
::
compile
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
{
// This will be replaced with the pass manager
// Get the ordered list of ops in execution order
pass
::
TopologicalSort
ts
;
ts
.
run_on_tree
(
f
->
get_result
());
auto
nodes
=
ts
.
get_call_graph
();
// Types
for
(
auto
node
:
nodes
)
{
node
->
propagate_types
();
}
// Determine tensors
for
(
auto
node
:
nodes
)
{
node
->
assign_tensors
();
}
// Determine tensor requirements for the call frame
unordered_map
<
shared_ptr
<
ngraph
::
descriptor
::
TensorView
>
,
size_t
>
tensor_index
;
// First come the function inputs
for
(
auto
param
:
f
->
get_parameters
())
{
for
(
auto
output
:
param
->
get_outputs
())
{
auto
tv
=
output
.
get_tensor_view
();
size_t
index
=
tensor_index
.
size
();
tensor_index
[
tv
]
=
index
;
}
}
m_n_inputs
=
tensor_index
.
size
();
// Next are the function outputs
for
(
auto
output
:
f
->
get_result
()
->
get_outputs
())
{
auto
tv
=
output
.
get_tensor_view
();
size_t
index
=
tensor_index
.
size
();
tensor_index
[
tv
]
=
index
;
}
m_n_outputs
=
tensor_index
.
size
()
-
m_n_inputs
;
// All remaining tensor views
for
(
auto
node
:
nodes
)
{
for
(
auto
output
:
node
->
get_outputs
())
{
auto
tv
=
output
.
get_tensor_view
();
if
(
0
==
tensor_index
.
count
(
tv
))
{
size_t
index
=
tensor_index
.
size
();
tensor_index
[
tv
]
=
index
;
m_temp_views
.
push_back
(
tv
);
}
}
}
// Now we build the eigen-VM instructions
auto
op_map
=
get_op_map
();
for
(
auto
node
:
nodes
)
{
auto
handler_it
=
op_map
.
find
(
type_index
(
typeid
(
*
node
)));
if
(
handler_it
==
op_map
.
end
())
{
throw
ngraph_error
(
"Unhandled op during code generation"
);
}
std
::
vector
<
size_t
>
in
;
for
(
auto
input
:
node
->
get_inputs
())
{
auto
output
=
input
.
get_output
();
auto
tv
=
output
.
get_tensor_view
();
in
.
push_back
(
tensor_index
.
at
(
tv
));
}
std
::
vector
<
size_t
>
out
;
for
(
auto
output
:
node
->
get_outputs
())
{
auto
tv
=
output
.
get_tensor_view
();
out
.
push_back
(
tensor_index
.
at
(
tv
));
}
handler_it
->
second
(
node
,
this
,
in
,
out
);
}
m_instructions
->
push_back
(
make_shared
<
runtime
::
eigen
::
ReturnInstruction
>
());
}
shared_ptr
<
ngraph
::
runtime
::
CallFrame
>
ExternalFunction
::
make_call_frame
()
{
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
PrimaryTensorView
>>
temps
;
for
(
auto
tv
:
m_temp_views
)
{
temps
.
push_back
(
ngraph
::
runtime
::
eigen
::
make_tensor_view
(
tv
));
}
return
make_shared
<
ngraph
::
runtime
::
CallFrame
>
(
m_n_inputs
,
m_n_outputs
,
temps
,
0
,
m_instructions
);
}
src/ngraph/runtime/eigen/external_function.hpp
0 → 100644
View file @
6a0ac42e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/function.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
eigen
{
class
ExternalFunction
{
public
:
ExternalFunction
();
void
compile
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
);
std
::
shared_ptr
<
ngraph
::
runtime
::
CallFrame
>
make_call_frame
();
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Instruction
>>>
get_instructions
()
{
return
m_instructions
;
}
protected
:
size_t
m_n_inputs
;
size_t
m_n_outputs
;
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Instruction
>>>
m_instructions
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
descriptor
::
TensorView
>>
m_temp_views
;
static
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
void
(
ngraph
::
Node
*
,
ExternalFunction
*
,
const
std
::
vector
<
size_t
>&
inputs
,
const
std
::
vector
<
size_t
>&
outputs
)
>>&
get_op_map
();
};
}
}
}
src/ngraph/runtime/eigen/tensor_view.cpp
View file @
6a0ac42e
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <Eigen/Dense>
#include <Eigen/Dense>
#include <memory>
#include "ngraph.hpp"
#include "ngraph.hpp"
...
@@ -27,3 +28,11 @@ template void ngraph::runtime::eigen::add<Float32>(const PrimaryTensorView<Float
...
@@ -27,3 +28,11 @@ template void ngraph::runtime::eigen::add<Float32>(const PrimaryTensorView<Float
template
void
ngraph
::
runtime
::
eigen
::
multiply
<
Float32
>
(
const
PrimaryTensorView
<
Float32
>&
arg0
,
template
void
ngraph
::
runtime
::
eigen
::
multiply
<
Float32
>
(
const
PrimaryTensorView
<
Float32
>&
arg0
,
const
PrimaryTensorView
<
Float32
>&
arg1
,
const
PrimaryTensorView
<
Float32
>&
arg1
,
PrimaryTensorView
<
Float32
>&
out
);
PrimaryTensorView
<
Float32
>&
out
);
std
::
shared_ptr
<
ngraph
::
runtime
::
PrimaryTensorView
>
ngraph
::
runtime
::
eigen
::
make_tensor_view
(
std
::
shared_ptr
<
ngraph
::
descriptor
::
TensorView
>
tensor_view
)
{
// For now, we only support Float32 primary tensor views
return
std
::
make_shared
<
PrimaryTensorView
<
Float32
>>
(
tensor_view
->
get_tensor_view_type
()
->
get_shape
());
}
src/ngraph/runtime/eigen/tensor_view.hpp
View file @
6a0ac42e
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "ngraph/shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -26,6 +27,8 @@ namespace ngraph
...
@@ -26,6 +27,8 @@ namespace ngraph
{
{
namespace
eigen
namespace
eigen
{
{
std
::
shared_ptr
<
ngraph
::
runtime
::
PrimaryTensorView
>
make_tensor_view
(
std
::
shared_ptr
<
ngraph
::
descriptor
::
TensorView
>
);
template
<
typename
ET
>
template
<
typename
ET
>
class
PrimaryTensorView
:
public
ngraph
::
runtime
::
PrimaryTensorView
class
PrimaryTensorView
:
public
ngraph
::
runtime
::
PrimaryTensorView
{
{
...
...
src/ngraph/runtime/tensor_view.hpp
View file @
6a0ac42e
...
@@ -22,7 +22,9 @@ namespace ngraph
...
@@ -22,7 +22,9 @@ namespace ngraph
class
PrimaryTensorView
class
PrimaryTensorView
{
{
public
:
public
:
virtual
~
PrimaryTensorView
(){}
virtual
~
PrimaryTensorView
()
{}
};
};
using
PTVs
=
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
PrimaryTensorView
>>
;
}
}
}
}
src/ngraph/type.hpp
View file @
6a0ac42e
...
@@ -30,6 +30,12 @@ namespace ngraph
...
@@ -30,6 +30,12 @@ namespace ngraph
/// | TupleType(ValueType[])
/// | TupleType(ValueType[])
class
ValueType
class
ValueType
{
{
ValueType
(
const
ValueType
&
)
=
delete
;
ValueType
&
operator
=
(
const
ValueType
&
)
=
delete
;
protected
:
ValueType
()
{}
public
:
public
:
virtual
~
ValueType
()
{}
virtual
~
ValueType
()
{}
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
=
0
;
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
=
0
;
...
@@ -48,7 +54,8 @@ namespace ngraph
...
@@ -48,7 +54,8 @@ namespace ngraph
/// /param element_type The type of the tensor elements.
/// /param element_type The type of the tensor elements.
/// /param shape The shape of the tensor.
/// /param shape The shape of the tensor.
TensorViewType
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
TensorViewType
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
:
m_element_type
(
element_type
)
:
ValueType
()
,
m_element_type
(
element_type
)
,
m_shape
(
shape
)
,
m_shape
(
shape
)
{
{
}
}
...
...
test/CMakeLists.txt
View file @
6a0ac42e
...
@@ -25,6 +25,7 @@ set (SRC
...
@@ -25,6 +25,7 @@ set (SRC
build_graph.cpp
build_graph.cpp
eigen.cpp
eigen.cpp
element_type.cpp
element_type.cpp
execute.cpp
input_output_assign.cpp
input_output_assign.cpp
main.cpp
main.cpp
op.cpp
op.cpp
...
...
test/execute.cpp
0 → 100644
View file @
6a0ac42e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
execute
,
test_abc
)
{
auto
shape
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Multiply
>
(
make_shared
<
op
::
Add
>
(
A
,
B
),
C
),
op
::
Parameters
{
A
,
B
,
C
});
auto
external
=
make_shared
<
ngraph
::
runtime
::
eigen
::
ExternalFunction
>
();
external
->
compile
(
f
);
auto
cf
=
external
->
make_call_frame
();
// Create some tensors for input/output
auto
a
=
make_shared
<
runtime
::
eigen
::
PrimaryTensorView
<
element
::
Float32
>>
(
shape
);
*
a
=
vector
<
float
>
{
1
,
2
,
3
,
4
};
auto
b
=
make_shared
<
runtime
::
eigen
::
PrimaryTensorView
<
element
::
Float32
>>
(
shape
);
*
b
=
vector
<
float
>
{
5
,
6
,
7
,
8
};
auto
c
=
make_shared
<
runtime
::
eigen
::
PrimaryTensorView
<
element
::
Float32
>>
(
shape
);
*
c
=
vector
<
float
>
{
9
,
10
,
11
,
12
};
auto
result
=
make_shared
<
runtime
::
eigen
::
PrimaryTensorView
<
element
::
Float32
>>
(
shape
);
(
*
cf
)(
runtime
::
PTVs
{
a
,
b
,
c
},
runtime
::
PTVs
{
result
});
ASSERT_EQ
((
vector
<
float
>
{
54
,
80
,
110
,
144
}),
result
->
get_vector
());
(
*
cf
)(
runtime
::
PTVs
{
b
,
a
,
c
},
runtime
::
PTVs
{
result
});
ASSERT_EQ
((
vector
<
float
>
{
54
,
80
,
110
,
144
}),
result
->
get_vector
());
(
*
cf
)(
runtime
::
PTVs
{
a
,
c
,
b
},
runtime
::
PTVs
{
result
});
ASSERT_EQ
((
vector
<
float
>
{
50
,
72
,
98
,
128
}),
result
->
get_vector
());
}
test/type_prop.cpp
View file @
6a0ac42e
...
@@ -351,7 +351,7 @@ TEST(type_prop, comparison_good)
...
@@ -351,7 +351,7 @@ TEST(type_prop, comparison_good)
auto
tv0_2_4_param_1
=
make_shared
<
op
::
Parameter
>
(
auto
tv0_2_4_param_1
=
make_shared
<
op
::
Parameter
>
(
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
4
}));
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
4
}));
auto
eq
=
make_shared
<
op
::
Equal
>
(
tv0_2_4_param_0
,
tv0_2_4_param_1
);
auto
eq
=
make_shared
<
op
::
Equal
>
(
tv0_2_4_param_0
,
tv0_2_4_param_1
);
auto
expected_type
=
TensorViewType
(
element
::
Bool
::
element_type
(),
Shape
{
2
,
4
})
;
TensorViewType
expected_type
{
element
::
Bool
::
element_type
(),
Shape
{
2
,
4
}}
;
eq
->
propagate_types
();
eq
->
propagate_types
();
EXPECT_EQ
(
*
eq
->
get_value_type
(),
expected_type
);
EXPECT_EQ
(
*
eq
->
get_value_type
(),
expected_type
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment