Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
4442777f
Commit
4442777f
authored
Sep 11, 2017
by
Adam Procter
Browse files
Options
Browse Files
Download
Plain Diff
Merge remote-tracking branch 'origin/master' into aprocter/doxygen
parents
73704ed0
f1608316
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
577 additions
and
24 deletions
+577
-24
CMakeLists.txt
src/CMakeLists.txt
+4
-0
input.cpp
src/ngraph/descriptor/input.cpp
+30
-0
input.hpp
src/ngraph/descriptor/input.hpp
+57
-0
output.cpp
src/ngraph/descriptor/output.cpp
+32
-0
output.hpp
src/ngraph/descriptor/output.hpp
+51
-0
tensor.cpp
src/ngraph/descriptor/tensor.cpp
+24
-0
tensor.hpp
src/ngraph/descriptor/tensor.hpp
+46
-0
tensor_view.cpp
src/ngraph/descriptor/tensor_view.cpp
+28
-0
tensor_view.hpp
src/ngraph/descriptor/tensor_view.hpp
+82
-0
tensor_view_layout.hpp
src/ngraph/descriptor/tensor_view_layout.hpp
+31
-0
ngraph.hpp
src/ngraph/ngraph.hpp
+5
-0
node.cpp
src/ngraph/node.cpp
+48
-17
node.hpp
src/ngraph/node.hpp
+15
-1
type.hpp
src/ngraph/type.hpp
+6
-1
type.cpp
src/types/type.cpp
+12
-0
CMakeLists.txt
test/CMakeLists.txt
+6
-5
input_output_assign.cpp
test/input_output_assign.cpp
+100
-0
No files found.
src/CMakeLists.txt
View file @
4442777f
...
@@ -15,6 +15,10 @@ set (SRC
...
@@ -15,6 +15,10 @@ set (SRC
tree.cpp
tree.cpp
util.cpp
util.cpp
log.cpp
log.cpp
ngraph/descriptor/input.cpp
ngraph/descriptor/output.cpp
ngraph/descriptor/tensor.cpp
ngraph/descriptor/tensor_view.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/concatenate.cpp
...
...
src/ngraph/descriptor/input.cpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
descriptor
;
Input
::
Input
(
Node
*
node
,
size_t
index
,
size_t
argno
,
size_t
arg_index
,
const
shared_ptr
<
Output
>&
output
)
:
m_node
(
node
)
,
m_index
(
index
)
,
m_argno
(
argno
)
,
m_arg_index
(
arg_index
)
,
m_output
(
output
)
{
output
->
add_input
(
this
);
}
src/ngraph/descriptor/input.hpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
namespace
ngraph
{
namespace
descriptor
{
class
Output
;
// Describes a tensor that is an input to an op, directly or indirectly via a tuple
class
Input
:
public
std
::
enable_shared_from_this
<
Input
>
{
Input
(
const
Input
&
)
=
delete
;
Input
&
operator
=
(
const
Input
&
)
=
delete
;
public
:
/// @param node The node that owns this input; not shared to prevent owner loop
/// @param index The position of this this tensor in all input tensors
/// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input
Input
(
Node
*
node
,
size_t
index
,
size_t
argno
,
size_t
arg_index
,
const
std
::
shared_ptr
<
Output
>&
output
);
std
::
shared_ptr
<
Node
>
get_node
()
{
return
m_node
->
shared_from_this
();
}
size_t
get_argno
()
const
{
return
m_argno
;
}
size_t
get_arg_index
()
const
{
return
m_arg_index
;
}
size_t
get_index
()
const
{
return
m_index
;
}
std
::
shared_ptr
<
Output
>
get_output
()
const
{
return
m_output
;
}
protected
:
Node
*
m_node
;
// The node we are an input for
size_t
m_index
;
// Index into all input tensors
size_t
m_argno
;
// Arg number for this input
size_t
m_arg_index
;
// Index into arg's tensors
std
::
shared_ptr
<
Output
>
m_output
;
};
}
}
src/ngraph/descriptor/output.cpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
descriptor
;
Output
::
Output
(
Node
*
node
,
size_t
index
,
const
std
::
shared_ptr
<
TensorView
>&
tensor_view
)
:
m_node
(
node
)
,
m_index
(
index
)
,
m_tensor_view
(
tensor_view
)
{
}
// Add an input to the vector of inputs that use this output.
void
Output
::
add_input
(
Input
*
input
)
{
m_inputs
.
insert
(
input
);
}
src/ngraph/descriptor/output.hpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <set>
#include "descriptor/tensor_view.hpp"
namespace
ngraph
{
namespace
descriptor
{
// Describes an output tensor of an op
class
Output
:
public
std
::
enable_shared_from_this
<
Output
>
{
Output
(
const
Output
&
)
=
delete
;
Output
&
operator
=
(
const
Output
&
)
=
delete
;
public
:
/// @param node Node that owns this output. Not shared to prevent owner loop.
/// @param index Position of the output tensor in all output tensors
/// @param tensor_view The view of this tensor; where the value will be written
Output
(
Node
*
node
,
size_t
index
,
const
std
::
shared_ptr
<
TensorView
>&
tensor_view
);
std
::
shared_ptr
<
Node
>
get_node
()
const
{
return
m_node
->
shared_from_this
();
}
size_t
get_index
()
const
{
return
m_index
;
}
std
::
shared_ptr
<
TensorView
>
get_tensor_view
()
const
{
return
m_tensor_view
;
}
void
add_input
(
Input
*
input
);
const
std
::
set
<
Input
*>&
get_inputs
()
const
{
return
m_inputs
;
}
protected
:
Node
*
m_node
;
size_t
m_index
;
std
::
shared_ptr
<
TensorView
>
m_tensor_view
;
std
::
set
<
Input
*>
m_inputs
;
};
}
}
src/ngraph/descriptor/tensor.cpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "descriptor/tensor.hpp"
using
namespace
ngraph
;
using
namespace
descriptor
;
Tensor
::
Tensor
(
const
element
::
Type
&
element_type
,
PrimaryTensorView
*
primary_tensor_view
)
:
m_element_type
(
element_type
)
,
m_primary_tensor_view
(
primary_tensor_view
)
{
}
src/ngraph/descriptor/tensor.hpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
namespace
ngraph
{
namespace
element
{
class
Type
;
}
namespace
descriptor
{
class
TensorView
;
class
PrimaryTensorView
;
class
Tensor
{
friend
class
PrimaryTensorView
;
Tensor
(
const
Tensor
&
)
=
delete
;
Tensor
&
operator
=
(
const
Tensor
&
)
=
delete
;
Tensor
(
const
element
::
Type
&
element_type
,
PrimaryTensorView
*
tensor_view
);
protected
:
const
element
::
Type
&
m_element_type
;
PrimaryTensorView
*
m_primary_tensor_view
;
};
}
}
src/ngraph/descriptor/tensor_view.cpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "descriptor/tensor_view.hpp"
using
namespace
ngraph
;
using
namespace
descriptor
;
const
Tensor
&
PrimaryTensorView
::
get_tensor
()
const
{
return
m_tensor
;
}
Tensor
&
PrimaryTensorView
::
get_tensor
()
{
return
m_tensor
;
}
src/ngraph/descriptor/tensor_view.hpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "descriptor/tensor.hpp"
#include "shape.hpp"
#include "type.hpp"
namespace
ngraph
{
namespace
descriptor
{
class
Tensor
;
class
TensorViewLayout
;
// Describes a view of an instantiated tensor
class
TensorView
:
public
std
::
enable_shared_from_this
<
TensorView
>
{
TensorView
(
const
TensorView
&
)
=
delete
;
TensorView
&
operator
=
(
const
TensorView
&
)
=
delete
;
protected
:
TensorView
(
const
std
::
shared_ptr
<
const
TensorViewType
>&
tensor_view_type
)
:
m_tensor_view_type
(
tensor_view_type
)
{
}
public
:
virtual
~
TensorView
()
{}
virtual
const
Tensor
&
get_tensor
()
const
=
0
;
virtual
Tensor
&
get_tensor
()
=
0
;
std
::
shared_ptr
<
const
TensorViewType
>
get_tensor_view_type
()
const
{
return
m_tensor_view_type
;
}
const
std
::
shared_ptr
<
TensorViewLayout
>&
get_tensor_view_layout
()
const
{
return
m_tensor_view_layout
;
}
void
set_tensor_view_layout
(
const
std
::
shared_ptr
<
TensorViewLayout
>&
tensor_view_layout
)
{
m_tensor_view_layout
=
tensor_view_layout
;
}
protected
:
std
::
shared_ptr
<
const
TensorViewType
>
m_tensor_view_type
;
std
::
shared_ptr
<
TensorViewLayout
>
m_tensor_view_layout
;
};
// A PrimaryTensorView owns the tensor. All other views are the result
// of some index operation on the primary view.
class
PrimaryTensorView
:
public
TensorView
{
public
:
PrimaryTensorView
(
const
std
::
shared_ptr
<
const
TensorViewType
>&
tensor_view_type
)
:
TensorView
(
tensor_view_type
)
,
m_tensor
(
tensor_view_type
->
get_element_type
(),
this
)
{
}
virtual
const
Tensor
&
get_tensor
()
const
override
;
virtual
Tensor
&
get_tensor
()
override
;
protected
:
Tensor
m_tensor
;
};
}
}
src/ngraph/descriptor/tensor_view_layout.hpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
namespace
ngraph
{
namespace
descriptor
{
using
Strides
=
std
::
vector
<
size_t
>
;
class
TensorViewLayout
{
protected
:
Strides
m_strides
;
};
}
}
src/ngraph/ngraph.hpp
View file @
4442777f
...
@@ -23,6 +23,11 @@
...
@@ -23,6 +23,11 @@
#include "except.hpp"
#include "except.hpp"
#include "function.hpp"
#include "function.hpp"
#include "node.hpp"
#include "node.hpp"
#include "descriptor/input.hpp"
#include "descriptor/output.hpp"
#include "descriptor/tensor_view.hpp"
#include "descriptor/tensor_view_layout.hpp"
#include "descriptor/tensor.hpp"
#include "op.hpp"
#include "op.hpp"
#include "ops/add.hpp"
#include "ops/add.hpp"
#include "ops/broadcast.hpp"
#include "ops/broadcast.hpp"
...
...
src/ngraph/node.cpp
View file @
4442777f
...
@@ -12,13 +12,14 @@
...
@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "ngraph.hpp"
#include "op.hpp"
size_t
ngraph
::
Node
::
m_next_instance_id
=
0
;
using
namespace
std
;
using
namespace
ngraph
;
ngraph
::
Node
::
Node
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
,
size_t
Node
::
m_next_instance_id
=
0
;
std
::
shared_ptr
<
ValueType
>
value_type
)
Node
::
Node
(
const
std
::
vector
<
shared_ptr
<
Node
>>&
arguments
,
shared_ptr
<
ValueType
>
value_type
)
:
m_arguments
(
arguments
)
:
m_arguments
(
arguments
)
,
m_value_type
(
value_type
)
,
m_value_type
(
value_type
)
,
m_instance_id
(
m_next_instance_id
++
)
,
m_instance_id
(
m_next_instance_id
++
)
...
@@ -30,33 +31,63 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
...
@@ -30,33 +31,63 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
}
}
}
}
void
ngraph
::
Node
::
set_value_type_checked
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
void
Node
::
set_value_type_checked
(
const
shared_ptr
<
ValueType
>&
value_type
)
{
{
if
(
nullptr
==
m_value_type
){
if
(
nullptr
==
m_value_type
)
{
m_value_type
=
value_type
;
m_value_type
=
value_type
;
}
else
{
}
if
(
*
m_value_type
!=
*
value_type
){
else
throw
ngraph
::
ngraph_error
(
"Setting value type to a different ValueType"
);
{
if
(
*
m_value_type
!=
*
value_type
)
{
throw
ngraph_error
(
"Setting value type to a different ValueType"
);
}
}
}
void
Node
::
assign_tensors
()
{
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>
tensor_view_types
;
get_value_type
()
->
collect_tensor_views
(
tensor_view_types
);
size_t
i
=
0
;
for
(
auto
tvt
:
tensor_view_types
)
{
auto
tensor_view_descriptor
=
make_shared
<
descriptor
::
PrimaryTensorView
>
(
tvt
);
auto
output
=
make_shared
<
descriptor
::
Output
>
(
this
,
i
++
,
tensor_view_descriptor
);
m_outputs
.
push_back
(
output
);
}
i
=
0
;
size_t
argno
=
0
;
for
(
auto
arg
:
get_arguments
())
{
size_t
arg_index
=
0
;
for
(
auto
output
:
arg
->
get_outputs
())
{
auto
input
=
make_shared
<
descriptor
::
Input
>
(
this
,
i
++
,
argno
,
arg_index
++
,
output
);
m_inputs
.
push_back
(
input
);
}
}
argno
++
;
}
}
}
}
bool
ngraph
::
Node
::
is_op
()
const
bool
Node
::
is_op
()
const
{
{
return
dynamic_cast
<
const
ngraph
::
Op
*>
(
this
)
!=
nullptr
;
return
dynamic_cast
<
const
Op
*>
(
this
)
!=
nullptr
;
}
}
bool
ngraph
::
Node
::
is_parameter
()
const
bool
Node
::
is_parameter
()
const
{
{
return
dynamic_cast
<
const
ngraph
::
op
::
Parameter
*>
(
this
)
!=
nullptr
;
return
dynamic_cast
<
const
op
::
Parameter
*>
(
this
)
!=
nullptr
;
}
}
namespace
ngraph
namespace
ngraph
{
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
ngraph
::
Node
&
node
)
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
{
{
auto
op_tmp
=
dynamic_cast
<
const
ngraph
::
Op
*>
(
&
node
);
auto
op_tmp
=
dynamic_cast
<
const
Op
*>
(
&
node
);
auto
parameter_tmp
=
dynamic_cast
<
const
ngraph
::
Op
*>
(
&
node
);
auto
parameter_tmp
=
dynamic_cast
<
const
Op
*>
(
&
node
);
if
(
op_tmp
)
if
(
op_tmp
)
{
{
out
<<
"Op("
<<
op_tmp
->
get_node_id
()
<<
")"
;
out
<<
"Op("
<<
op_tmp
->
get_node_id
()
<<
")"
;
...
...
src/ngraph/node.hpp
View file @
4442777f
...
@@ -27,6 +27,12 @@ namespace ngraph
...
@@ -27,6 +27,12 @@ namespace ngraph
{
{
class
Op
;
class
Op
;
namespace
descriptor
{
class
Input
;
class
Output
;
}
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor
/// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values.
/// view or a (possibly empty) tuple of values.
...
@@ -53,6 +59,10 @@ namespace ngraph
...
@@ -53,6 +59,10 @@ namespace ngraph
/// Propagate types and check arguments for consistency
/// Propagate types and check arguments for consistency
virtual
void
propagate_types
()
=
0
;
virtual
void
propagate_types
()
=
0
;
/// Assign Input and Output vectors
// This might later need to be virtual.
void
assign_tensors
();
const
Nodes
&
get_arguments
()
const
{
return
m_arguments
;
}
const
Nodes
&
get_arguments
()
const
{
return
m_arguments
;
}
const
std
::
multiset
<
Node
*>&
users
()
const
{
return
m_users
;
}
const
std
::
multiset
<
Node
*>&
users
()
const
{
return
m_users
;
}
...
@@ -94,7 +104,9 @@ namespace ngraph
...
@@ -94,7 +104,9 @@ namespace ngraph
size_t
get_instance_id
()
const
{
return
m_instance_id
;
}
size_t
get_instance_id
()
const
{
return
m_instance_id
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
std
::
vector
<
std
::
shared_ptr
<
descriptor
::
Input
>>
get_inputs
()
{
return
m_inputs
;
}
std
::
vector
<
std
::
shared_ptr
<
descriptor
::
Output
>>
get_outputs
()
{
return
m_outputs
;
}
protected
:
protected
:
Nodes
m_arguments
;
Nodes
m_arguments
;
...
@@ -103,5 +115,7 @@ namespace ngraph
...
@@ -103,5 +115,7 @@ namespace ngraph
std
::
string
m_name
;
std
::
string
m_name
;
size_t
m_instance_id
;
size_t
m_instance_id
;
static
size_t
m_next_instance_id
;
static
size_t
m_next_instance_id
;
std
::
vector
<
std
::
shared_ptr
<
descriptor
::
Input
>>
m_inputs
;
std
::
vector
<
std
::
shared_ptr
<
descriptor
::
Output
>>
m_outputs
;
};
};
}
}
src/ngraph/type.hpp
View file @
4442777f
...
@@ -34,10 +34,13 @@ namespace ngraph
...
@@ -34,10 +34,13 @@ namespace ngraph
virtual
~
ValueType
()
{}
virtual
~
ValueType
()
{}
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
=
0
;
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
=
0
;
bool
operator
!=
(
const
ValueType
&
that
)
const
{
return
!
(
*
this
==
that
);
}
bool
operator
!=
(
const
ValueType
&
that
)
const
{
return
!
(
*
this
==
that
);
}
/// Add tensor views in depth-first order.
virtual
void
collect_tensor_views
(
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
=
0
;
};
};
/// Describes a tensor view; an element type and a shape.
/// Describes a tensor view; an element type and a shape.
class
TensorViewType
:
public
ValueType
class
TensorViewType
:
public
ValueType
,
public
std
::
enable_shared_from_this
<
TensorViewType
>
{
{
public
:
public
:
/// /param element_type The type of the tensor elements.
/// /param element_type The type of the tensor elements.
...
@@ -52,6 +55,7 @@ namespace ngraph
...
@@ -52,6 +55,7 @@ namespace ngraph
const
Shape
&
get_shape
()
const
{
return
m_shape
;
}
const
Shape
&
get_shape
()
const
{
return
m_shape
;
}
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
override
;
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
override
;
virtual
void
collect_tensor_views
(
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
override
;
protected
:
protected
:
const
element
::
Type
&
m_element_type
;
const
element
::
Type
&
m_element_type
;
...
@@ -78,6 +82,7 @@ namespace ngraph
...
@@ -78,6 +82,7 @@ namespace ngraph
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
set_element_types
()
{
return
m_element_types
;
}
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
set_element_types
()
{
return
m_element_types
;
}
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
override
;
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
override
;
virtual
void
collect_tensor_views
(
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
override
;
protected
:
protected
:
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
m_element_types
;
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
m_element_types
;
...
...
src/types/type.cpp
View file @
4442777f
...
@@ -37,6 +37,11 @@ bool TensorViewType::operator==(const ValueType& that) const
...
@@ -37,6 +37,11 @@ bool TensorViewType::operator==(const ValueType& that) const
return
true
;
return
true
;
}
}
void
TensorViewType
::
collect_tensor_views
(
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
{
views
.
push_back
(
shared_from_this
());
}
bool
TupleType
::
operator
==
(
const
ValueType
&
that
)
const
bool
TupleType
::
operator
==
(
const
ValueType
&
that
)
const
{
{
auto
that_tvt
=
dynamic_cast
<
const
TupleType
*>
(
&
that
);
auto
that_tvt
=
dynamic_cast
<
const
TupleType
*>
(
&
that
);
...
@@ -46,3 +51,10 @@ bool TupleType::operator==(const ValueType& that) const
...
@@ -46,3 +51,10 @@ bool TupleType::operator==(const ValueType& that) const
}
}
return
that_tvt
->
get_element_types
()
==
get_element_types
();
return
that_tvt
->
get_element_types
()
==
get_element_types
();
}
}
void
TupleType
::
collect_tensor_views
(
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
{
for
(
auto
elt
:
m_element_types
){
elt
->
collect_tensor_views
(
views
);
}
}
test/CMakeLists.txt
View file @
4442777f
...
@@ -24,14 +24,15 @@ include_directories(
...
@@ -24,14 +24,15 @@ include_directories(
set
(
SRC
set
(
SRC
main.cpp
main.cpp
build_graph.cpp
build_graph.cpp
util.cpp
eigen.cpp
tensor.cpp
element_type.cpp
element_type.cpp
uuid.cpp
op.cpp
input_output_assign.cpp
tensor.cpp
topological_sort.cpp
topological_sort.cpp
type_prop.cpp
type_prop.cpp
op
.cpp
util
.cpp
eigen
.cpp
uuid
.cpp
)
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++11"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++11"
)
...
...
test/input_output_assign.cpp
0 → 100644
View file @
4442777f
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include <memory>
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
input_output
,
param_tensor
)
{
// Params have no arguments, so we can check that the value becomes a tensor output
auto
tv_tp
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
4
});
auto
param
=
make_shared
<
op
::
Parameter
>
(
tv_tp
);
param
->
propagate_types
();
param
->
assign_tensors
();
ASSERT_EQ
(
param
->
get_outputs
().
size
(),
1
);
for
(
size_t
i
=
0
;
i
<
param
->
get_outputs
().
size
();
i
++
)
{
auto
output
=
param
->
get_outputs
()[
i
];
ASSERT_EQ
(
i
,
output
->
get_index
());
ASSERT_EQ
(
param
,
output
->
get_node
());
}
ASSERT_EQ
(
*
tv_tp
,
*
param
->
get_outputs
()[
0
]
->
get_tensor_view
()
->
get_tensor_view_type
());
}
TEST
(
input_output
,
param_tuple
)
{
// Same as param_tensor, but for a tuple
auto
tv_tp_0
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
4
});
auto
tv_tp_1
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
4
,
6
});
auto
tp_tp
=
make_shared
<
TupleType
>
(
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
{
tv_tp_0
,
tv_tp_1
});
auto
param
=
make_shared
<
op
::
Parameter
>
(
tp_tp
);
param
->
propagate_types
();
param
->
assign_tensors
();
ASSERT_EQ
(
param
->
get_outputs
().
size
(),
2
);
for
(
size_t
i
=
0
;
i
<
param
->
get_outputs
().
size
();
i
++
)
{
auto
output
=
param
->
get_outputs
()[
i
];
ASSERT_EQ
(
i
,
output
->
get_index
());
ASSERT_EQ
(
param
,
output
->
get_node
());
}
ASSERT_EQ
(
*
tv_tp_0
,
*
param
->
get_outputs
()[
0
]
->
get_tensor_view
()
->
get_tensor_view_type
());
ASSERT_EQ
(
*
tv_tp_1
,
*
param
->
get_outputs
()[
1
]
->
get_tensor_view
()
->
get_tensor_view_type
());
}
TEST
(
input_output
,
simple_output
)
{
auto
tv_tp_0
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
4
});
auto
param_0
=
make_shared
<
op
::
Parameter
>
(
tv_tp_0
);
auto
param_1
=
make_shared
<
op
::
Parameter
>
(
tv_tp_0
);
auto
add
=
make_shared
<
op
::
Add
>
(
param_0
,
param_1
);
// Sort the ops
vector
<
shared_ptr
<
Node
>>
nodes
;
nodes
.
push_back
(
param_0
);
nodes
.
push_back
(
param_1
);
nodes
.
push_back
(
add
);
// Type info
for
(
auto
node
:
nodes
)
{
node
->
propagate_types
();
}
// Add inputs/outputs
for
(
auto
node
:
nodes
)
{
node
->
assign_tensors
();
}
// At this point, the add should have each input associated with the output of the appropriate parameter
auto
inputs
=
add
->
get_inputs
();
ASSERT_EQ
(
2
,
inputs
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
i
++
)
{
auto
input
=
inputs
[
i
];
ASSERT_EQ
(
i
,
input
->
get_index
());
ASSERT_EQ
(
i
,
input
->
get_argno
());
ASSERT_EQ
(
0
,
input
->
get_arg_index
());
ASSERT_EQ
(
input
->
get_output
()
->
get_node
(),
add
->
get_arguments
()[
i
]);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment