Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
d87b0065
Unverified
Commit
d87b0065
authored
Jan 24, 2018
by
Scott Cyphers
Committed by
GitHub
Jan 24, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Remove TupleType, ValueType (#411)
* Remove TupleType, ValueType * Fix compile error.
parent
f6c6daef
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
16 additions
and
159 deletions
+16
-159
tensor_view.cpp
src/ngraph/descriptor/tensor_view.cpp
+1
-1
tensor_view.hpp
src/ngraph/descriptor/tensor_view.hpp
+1
-1
node.cpp
src/ngraph/node.cpp
+1
-1
node.hpp
src/ngraph/node.hpp
+1
-1
type.cpp
src/ngraph/types/type.cpp
+2
-74
type.hpp
src/ngraph/types/type.hpp
+10
-63
build_graph.cpp
test/build_graph.cpp
+0
-18
No files found.
src/ngraph/descriptor/tensor_view.cpp
View file @
d87b0065
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
std
;
using
namespace
std
;
shared_ptr
<
const
ngraph
::
Value
Type
>
descriptor
::
TensorView
::
get_value_type
()
const
shared_ptr
<
const
ngraph
::
TensorView
Type
>
descriptor
::
TensorView
::
get_value_type
()
const
{
{
return
m_tensor_view_type
;
return
m_tensor_view_type
;
}
}
src/ngraph/descriptor/tensor_view.hpp
View file @
d87b0065
...
@@ -51,7 +51,7 @@ namespace ngraph
...
@@ -51,7 +51,7 @@ namespace ngraph
virtual
const
Tensor
&
get_tensor
()
const
=
0
;
virtual
const
Tensor
&
get_tensor
()
const
=
0
;
virtual
Tensor
&
get_tensor
()
=
0
;
virtual
Tensor
&
get_tensor
()
=
0
;
virtual
std
::
shared_ptr
<
const
Value
Type
>
get_value_type
()
const
;
virtual
std
::
shared_ptr
<
const
TensorView
Type
>
get_value_type
()
const
;
const
std
::
string
&
get_name
()
const
{
return
m_name
;
}
const
std
::
string
&
get_name
()
const
{
return
m_name
;
}
std
::
shared_ptr
<
const
TensorViewType
>
get_tensor_view_type
()
const
std
::
shared_ptr
<
const
TensorViewType
>
get_tensor_view_type
()
const
...
...
src/ngraph/node.cpp
View file @
d87b0065
...
@@ -70,7 +70,7 @@ void Node::add_output(const element::Type& element_type, const Shape& shape)
...
@@ -70,7 +70,7 @@ void Node::add_output(const element::Type& element_type, const Shape& shape)
m_outputs
.
emplace_back
(
this
,
i
,
tensor_view_descriptor
);
m_outputs
.
emplace_back
(
this
,
i
,
tensor_view_descriptor
);
}
}
void
Node
::
set_value_type_checked
(
const
shared_ptr
<
const
Value
Type
>&
value_type
)
void
Node
::
set_value_type_checked
(
const
shared_ptr
<
const
TensorView
Type
>&
value_type
)
{
{
set_value_type_checked
(
value_type
->
get_element_type
(),
value_type
->
get_shape
());
set_value_type_checked
(
value_type
->
get_element_type
(),
value_type
->
get_shape
());
}
}
...
...
src/ngraph/node.hpp
View file @
d87b0065
...
@@ -85,7 +85,7 @@ namespace ngraph
...
@@ -85,7 +85,7 @@ namespace ngraph
// value_type agrees with the value type that was set.
// value_type agrees with the value type that was set.
// This is used when the framework specifies a value type for the value, and we
// This is used when the framework specifies a value type for the value, and we
// independently compute what we thing the value type should be from the arguments.
// independently compute what we thing the value type should be from the arguments.
void
set_value_type_checked
(
const
std
::
shared_ptr
<
const
Value
Type
>&
value_type
);
void
set_value_type_checked
(
const
std
::
shared_ptr
<
const
TensorView
Type
>&
value_type
);
void
set_value_type_checked
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
);
void
set_value_type_checked
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
);
bool
is_parameter
()
const
;
bool
is_parameter
()
const
;
...
...
src/ngraph/types/type.cpp
View file @
d87b0065
...
@@ -22,16 +22,15 @@
...
@@ -22,16 +22,15 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
bool
ValueType
::
operator
!=
(
const
Value
Type
&
that
)
const
bool
TensorViewType
::
operator
!=
(
const
TensorView
Type
&
that
)
const
{
{
return
!
(
*
this
==
that
);
return
!
(
*
this
==
that
);
}
}
bool
TensorViewType
::
operator
==
(
const
Value
Type
&
that
)
const
bool
TensorViewType
::
operator
==
(
const
TensorView
Type
&
that
)
const
{
{
bool
rc
=
true
;
bool
rc
=
true
;
auto
that_tvt
=
dynamic_cast
<
const
TensorViewType
*>
(
&
that
);
auto
that_tvt
=
dynamic_cast
<
const
TensorViewType
*>
(
&
that
);
auto
that_tt
=
dynamic_cast
<
const
TupleType
*>
(
&
that
);
if
(
that_tvt
!=
nullptr
)
if
(
that_tvt
!=
nullptr
)
{
{
rc
=
true
;
rc
=
true
;
...
@@ -44,10 +43,6 @@ bool TensorViewType::operator==(const ValueType& that) const
...
@@ -44,10 +43,6 @@ bool TensorViewType::operator==(const ValueType& that) const
rc
=
false
;
rc
=
false
;
}
}
}
}
else
if
(
that_tt
!=
nullptr
)
{
rc
=
*
that_tt
==
*
this
;
}
return
rc
;
return
rc
;
}
}
...
@@ -57,75 +52,8 @@ void TensorViewType::collect_tensor_views(
...
@@ -57,75 +52,8 @@ void TensorViewType::collect_tensor_views(
views
.
push_back
(
shared_from_this
());
views
.
push_back
(
shared_from_this
());
}
}
bool
TupleType
::
operator
==
(
const
ValueType
&
that
)
const
{
auto
that_tvt
=
dynamic_cast
<
const
TupleType
*>
(
&
that
);
if
(
that_tvt
==
nullptr
)
{
return
false
;
}
vector
<
shared_ptr
<
const
ValueType
>>
this_values
=
this
->
get_element_types
();
vector
<
shared_ptr
<
const
ValueType
>>
that_values
=
that_tvt
->
get_element_types
();
bool
rc
=
this_values
.
size
()
==
that_values
.
size
();
if
(
rc
)
{
for
(
size_t
i
=
0
;
i
<
this_values
.
size
();
i
++
)
{
rc
&=
this_values
[
i
]
->
get_element_type
()
==
that_values
[
i
]
->
get_element_type
();
}
}
return
rc
;
}
void
TupleType
::
collect_tensor_views
(
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
{
for
(
auto
elt
:
m_element_types
)
{
elt
->
collect_tensor_views
(
views
);
}
}
const
Shape
&
TupleType
::
get_shape
()
const
{
throw
ngraph_error
(
"get_shape() called on Tuple"
);
}
const
element
::
Type
&
TupleType
::
get_element_type
()
const
{
throw
ngraph_error
(
"get_element_type() called on Tuple"
);
}
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
ValueType
&
obj
)
{
auto
tvt
=
dynamic_cast
<
const
TensorViewType
*>
(
&
obj
);
auto
tup
=
dynamic_cast
<
const
TupleType
*>
(
&
obj
);
if
(
tvt
!=
nullptr
)
{
out
<<
*
tvt
;
}
else
if
(
tup
!=
nullptr
)
{
out
<<
*
tup
;
}
else
{
out
<<
"ValueType()"
;
}
return
out
;
}
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
TensorViewType
&
obj
)
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
TensorViewType
&
obj
)
{
{
out
<<
"TensorViewType("
<<
obj
.
m_element_type
<<
", {"
<<
join
(
obj
.
m_shape
)
<<
"})"
;
out
<<
"TensorViewType("
<<
obj
.
m_element_type
<<
", {"
<<
join
(
obj
.
m_shape
)
<<
"})"
;
return
out
;
return
out
;
}
}
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
TupleType
&
obj
)
{
out
<<
"TupleType()"
;
return
out
;
}
src/ngraph/types/type.hpp
View file @
d87b0065
...
@@ -23,48 +23,28 @@
...
@@ -23,48 +23,28 @@
namespace
ngraph
namespace
ngraph
{
{
class
TensorViewType
;
class
TensorViewType
;
class
TupleType
;
/// ValueType is
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
TensorViewType
&
);
/// TensorViewType
/// | TupleType(ValueType[])
class
ValueType
{
ValueType
(
const
ValueType
&
)
=
delete
;
ValueType
&
operator
=
(
const
ValueType
&
)
=
delete
;
protected
:
ValueType
()
{}
public
:
virtual
~
ValueType
()
{}
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
=
0
;
bool
operator
!=
(
const
ValueType
&
that
)
const
;
/// Add tensor views in depth-first order.
virtual
void
collect_tensor_views
(
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
=
0
;
virtual
const
Shape
&
get_shape
()
const
=
0
;
virtual
const
element
::
Type
&
get_element_type
()
const
=
0
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
ValueType
&
);
};
/// Describes a tensor view; an element type and a shape.
/// Describes a tensor view; an element type and a shape.
class
TensorViewType
:
public
ValueType
,
public
std
::
enable_shared_from_this
<
TensorViewType
>
class
TensorViewType
:
public
std
::
enable_shared_from_this
<
TensorViewType
>
{
{
TensorViewType
&
operator
=
(
const
ValueType
&
)
=
delete
;
public
:
public
:
/// /param element_type The type of the tensor elements.
/// /param element_type The type of the tensor elements.
/// /param shape The shape of the tensor.
/// /param shape The shape of the tensor.
TensorViewType
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
TensorViewType
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
:
ValueType
()
:
m_element_type
(
element_type
)
,
m_element_type
(
element_type
)
,
m_shape
(
shape
)
,
m_shape
(
shape
)
{
{
}
}
virtual
const
element
::
Type
&
get_element_type
()
const
override
{
return
m_element_type
;
}
const
element
::
Type
&
get_element_type
()
const
{
return
m_element_type
;
}
virtual
const
Shape
&
get_shape
()
const
override
{
return
m_shape
;
}
const
Shape
&
get_shape
()
const
{
return
m_shape
;
}
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
override
;
bool
operator
==
(
const
TensorViewType
&
that
)
const
;
virtual
void
collect_tensor_views
(
bool
operator
!=
(
const
TensorViewType
&
that
)
const
;
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
override
;
void
collect_tensor_views
(
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
TensorViewType
&
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
TensorViewType
&
);
...
@@ -72,37 +52,4 @@ namespace ngraph
...
@@ -72,37 +52,4 @@ namespace ngraph
const
element
::
Type
m_element_type
;
const
element
::
Type
m_element_type
;
Shape
m_shape
;
Shape
m_shape
;
};
};
/// Describes a tuple of values; a vector of types
class
TupleType
:
public
ValueType
{
public
:
/// Construct empty tuple and add value types later.
TupleType
()
{}
/// @param element_types A vector of types for the tuple elements
TupleType
(
const
std
::
vector
<
std
::
shared_ptr
<
const
ValueType
>>&
element_types
)
:
m_element_types
(
element_types
)
{
}
const
std
::
vector
<
std
::
shared_ptr
<
const
ValueType
>>
get_element_types
()
const
{
return
m_element_types
;
}
std
::
vector
<
std
::
shared_ptr
<
const
ValueType
>>
set_element_types
()
{
return
m_element_types
;
}
virtual
const
element
::
Type
&
get_element_type
()
const
override
;
virtual
bool
operator
==
(
const
ValueType
&
that
)
const
override
;
virtual
void
collect_tensor_views
(
std
::
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>&
views
)
const
override
;
virtual
const
Shape
&
get_shape
()
const
override
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
TupleType
&
);
protected
:
std
::
vector
<
std
::
shared_ptr
<
const
ValueType
>>
m_element_types
;
};
}
}
test/build_graph.cpp
View file @
d87b0065
...
@@ -39,24 +39,6 @@ TEST(build_graph, build_simple)
...
@@ -39,24 +39,6 @@ TEST(build_graph, build_simple)
ASSERT_EQ
(
cluster_0
->
get_output_op
(
0
),
dot
);
ASSERT_EQ
(
cluster_0
->
get_output_op
(
0
),
dot
);
}
}
// Check upcasting from ValueType.
TEST
(
build_graph
,
as_type
)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
auto
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
f32
,
Shape
{
2
,
3
,
5
});
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
ASSERT_EQ
(
tv_vt
,
tv_tv
);
auto
tv_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tv_vt
);
ASSERT_EQ
(
nullptr
,
tv_tp
);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
auto
tp_vt
=
make_shared
<
TupleType
>
(
ValueTypes
{
tv_vt
,
tv_vt
});
auto
tp_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tp_vt
);
ASSERT_EQ
(
nullptr
,
tp_tv
);
auto
tp_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tp_vt
);
ASSERT_EQ
(
tp_vt
,
tp_tp
);
}
// Check node comparisons
// Check node comparisons
TEST
(
build_graph
,
node_comparison
)
TEST
(
build_graph
,
node_comparison
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment