Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
40ddf45a
Unverified
Commit
40ddf45a
authored
Aug 20, 2018
by
Scott Cyphers
Committed by
GitHub
Aug 20, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Allow tensor[view, layout] element type and shape to be modified (#1440)
parent
e63ffa29
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
82 additions
and
21 deletions
+82
-21
tensor_view_layout.cpp
src/ngraph/descriptor/layout/tensor_view_layout.cpp
+11
-3
tensor_view_layout.hpp
src/ngraph/descriptor/layout/tensor_view_layout.hpp
+4
-1
output.hpp
src/ngraph/descriptor/output.hpp
+4
-0
primary_tensor_view.cpp
src/ngraph/descriptor/primary_tensor_view.cpp
+6
-0
primary_tensor_view.hpp
src/ngraph/descriptor/primary_tensor_view.hpp
+2
-0
tensor.cpp
src/ngraph/descriptor/tensor.cpp
+5
-0
tensor.hpp
src/ngraph/descriptor/tensor.hpp
+2
-1
tensor_view.cpp
src/ngraph/descriptor/tensor_view.cpp
+21
-0
tensor_view.hpp
src/ngraph/descriptor/tensor_view.hpp
+7
-0
element_type.cpp
src/ngraph/type/element_type.cpp
+13
-10
element_type.hpp
src/ngraph/type/element_type.hpp
+7
-6
No files found.
src/ngraph/descriptor/layout/tensor_view_layout.cpp
View file @
40ddf45a
...
...
@@ -22,16 +22,24 @@
using
namespace
ngraph
;
descriptor
::
layout
::
TensorViewLayout
::
TensorViewLayout
(
const
descriptor
::
TensorView
&
tensor_view
)
:
m_tensor_view_type
(
tensor_view
.
get_tensor_view_type
())
:
m_element_type
(
tensor_view
.
get_element_type
())
,
m_shape
(
tensor_view
.
get_shape
())
{
}
const
element
::
Type
&
descriptor
::
layout
::
TensorViewLayout
::
get_element_type
()
const
{
return
m_
tensor_view_type
->
get_element_type
()
;
return
m_
element_type
;
}
const
Shape
&
descriptor
::
layout
::
TensorViewLayout
::
get_shape
()
const
{
return
m_tensor_view_type
->
get_shape
();
return
m_shape
;
}
void
descriptor
::
layout
::
TensorViewLayout
::
set_tensor_view_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
{
m_element_type
=
element_type
;
m_shape
=
shape
;
}
src/ngraph/descriptor/layout/tensor_view_layout.hpp
View file @
40ddf45a
...
...
@@ -62,8 +62,11 @@ namespace ngraph
/// @brief Return true if this and other have the same element interpretation
virtual
bool
operator
==
(
const
TensorViewLayout
&
other
)
const
=
0
;
bool
operator
!=
(
const
TensorViewLayout
&
other
)
const
{
return
!
(
*
this
==
other
);
}
void
set_tensor_view_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
);
protected
:
std
::
shared_ptr
<
const
TensorViewType
>
m_tensor_view_type
;
element
::
Type
m_element_type
;
Shape
m_shape
;
};
}
}
...
...
src/ngraph/descriptor/output.hpp
View file @
40ddf45a
...
...
@@ -45,6 +45,10 @@ namespace ngraph
std
::
shared_ptr
<
Node
>
get_node
()
const
;
size_t
get_index
()
const
{
return
m_index
;
}
std
::
shared_ptr
<
TensorView
>
get_tensor_view
()
const
{
return
m_tensor_view
;
}
void
set_tensor_view
(
const
std
::
shared_ptr
<
TensorView
>&
tensor_view
)
{
m_tensor_view
=
tensor_view
;
}
void
add_input
(
Input
*
input
);
void
remove_input
(
Input
*
input
);
const
std
::
set
<
Input
*>&
get_inputs
()
const
{
return
m_inputs
;
}
...
...
src/ngraph/descriptor/primary_tensor_view.cpp
View file @
40ddf45a
...
...
@@ -38,3 +38,9 @@ Tensor& PrimaryTensorView::get_tensor()
{
return
m_tensor
;
}
void
PrimaryTensorView
::
set_tensor_view_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
{
TensorView
::
set_tensor_view_type
(
element_type
,
shape
);
m_tensor
.
set_element_type
(
element_type
);
}
src/ngraph/descriptor/primary_tensor_view.hpp
View file @
40ddf45a
...
...
@@ -39,6 +39,8 @@ namespace ngraph
virtual
const
Tensor
&
get_tensor
()
const
override
;
virtual
Tensor
&
get_tensor
()
override
;
void
set_tensor_view_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
override
;
protected
:
Tensor
m_tensor
;
...
...
src/ngraph/descriptor/tensor.cpp
View file @
40ddf45a
...
...
@@ -62,6 +62,11 @@ size_t descriptor::Tensor::get_pool_offset() const
return
m_pool_offset
;
}
void
descriptor
::
Tensor
::
set_element_type
(
const
element
::
Type
&
element_type
)
{
m_element_type
=
element_type
;
}
ostream
&
operator
<<
(
ostream
&
out
,
const
descriptor
::
Tensor
&
tensor
)
{
out
<<
"Tensor("
<<
tensor
.
get_name
()
<<
")"
;
...
...
src/ngraph/descriptor/tensor.hpp
View file @
40ddf45a
...
...
@@ -58,10 +58,11 @@ public:
void
set_pool_offset
(
size_t
);
size_t
get_pool_offset
()
const
;
const
element
::
Type
&
get_element_type
()
const
{
return
m_element_type
;
}
void
set_element_type
(
const
element
::
Type
&
element_type
);
static
std
::
string
make_tensor_name
(
const
Node
*
node
,
size_t
value_index
);
protected
:
const
element
::
Type
m_element_type
;
element
::
Type
m_element_type
;
PrimaryTensorView
*
m_primary_tensor_view
;
std
::
string
m_name
;
size_t
m_next_view_id
;
...
...
src/ngraph/descriptor/tensor_view.cpp
View file @
40ddf45a
...
...
@@ -15,6 +15,7 @@
*******************************************************************************/
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/type/type.hpp"
using
namespace
ngraph
;
...
...
@@ -24,3 +25,23 @@ shared_ptr<const ngraph::TensorViewType> descriptor::TensorView::get_value_type(
{
return
m_tensor_view_type
;
}
const
element
::
Type
&
descriptor
::
TensorView
::
get_element_type
()
const
{
return
m_tensor_view_type
->
get_element_type
();
}
const
Shape
&
descriptor
::
TensorView
::
get_shape
()
const
{
return
m_tensor_view_type
->
get_shape
();
}
void
descriptor
::
TensorView
::
set_tensor_view_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
{
m_tensor_view_type
=
make_shared
<
ngraph
::
TensorViewType
>
(
element_type
,
shape
);
if
(
nullptr
!=
m_tensor_view_layout
)
{
m_tensor_view_layout
->
set_tensor_view_type
(
element_type
,
shape
);
}
}
src/ngraph/descriptor/tensor_view.hpp
View file @
40ddf45a
...
...
@@ -20,6 +20,7 @@
#include <string>
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace
ngraph
{
...
...
@@ -62,6 +63,12 @@ namespace ngraph
return
m_tensor_view_type
;
}
virtual
void
set_tensor_view_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
);
const
element
::
Type
&
get_element_type
()
const
;
const
Shape
&
get_shape
()
const
;
const
std
::
shared_ptr
<
layout
::
TensorViewLayout
>&
get_tensor_view_layout
()
const
{
return
m_tensor_view_layout
;
...
...
src/ngraph/type/element_type.cpp
View file @
40ddf45a
...
...
@@ -15,11 +15,13 @@
*******************************************************************************/
#include <cmath>
#include <iostream>
#include "ngraph/type/element_type.hpp"
using
namespace
ngraph
;
const
element
::
Type
element
::
unspecified
(
0
,
false
,
false
,
"unspecified"
);
const
element
::
Type
element
::
boolean
(
8
,
false
,
true
,
"char"
);
const
element
::
Type
element
::
f32
(
32
,
true
,
true
,
"float"
);
const
element
::
Type
element
::
f64
(
64
,
true
,
true
,
"double"
);
...
...
@@ -48,14 +50,6 @@ std::vector<const element::Type*> element::Type::get_known_types()
return
rc
;
}
element
::
Type
::
Type
()
:
m_bitwidth
{
0
}
,
m_is_real
{
0
}
,
m_is_signed
{
0
}
,
m_cname
{}
{
}
element
::
Type
::
Type
(
size_t
bitwidth
,
bool
is_real
,
bool
is_signed
,
const
std
::
string
&
cname
)
:
m_bitwidth
{
bitwidth
}
,
m_is_real
{
is_real
}
...
...
@@ -64,6 +58,15 @@ element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::st
{
}
element
::
Type
&
element
::
Type
::
operator
=
(
const
element
::
Type
&
t
)
{
m_bitwidth
=
t
.
m_bitwidth
;
m_is_real
=
t
.
m_is_real
;
m_is_signed
=
t
.
m_is_signed
;
m_cname
=
t
.
m_cname
;
return
*
this
;
}
const
std
::
string
&
element
::
Type
::
c_type_string
()
const
{
return
m_cname
;
...
...
@@ -170,7 +173,7 @@ namespace ngraph
std
::
ostream
&
element
::
operator
<<
(
std
::
ostream
&
out
,
const
element
::
Type
&
obj
)
{
out
<<
"element::Type
(
"
<<
obj
.
m_bitwidth
<<
", "
<<
obj
.
m_is_real
<<
", "
<<
obj
.
m_is_signed
<<
"
)
"
;
out
<<
"element::Type
{
"
<<
obj
.
m_bitwidth
<<
", "
<<
obj
.
m_is_real
<<
", "
<<
obj
.
m_is_signed
<<
"
,"
<<
obj
.
m_cname
<<
"}
"
;
return
out
;
}
src/ngraph/type/element_type.hpp
View file @
40ddf45a
...
...
@@ -33,6 +33,7 @@ namespace ngraph
{
class
Type
;
extern
const
Type
unspecified
;
extern
const
Type
boolean
;
extern
const
Type
f32
;
extern
const
Type
f64
;
...
...
@@ -48,10 +49,10 @@ namespace ngraph
class
Type
{
public
:
Type
()
;
Type
()
{}
Type
(
const
Type
&
)
=
default
;
Type
(
size_t
bitwidth
,
bool
is_real
,
bool
is_signed
,
const
std
::
string
&
cname
);
Type
&
operator
=
(
const
Type
&
)
=
default
;
Type
&
operator
=
(
const
Type
&
);
virtual
~
Type
()
{}
const
std
::
string
&
c_type_string
()
const
;
size_t
size
()
const
;
...
...
@@ -68,10 +69,10 @@ namespace ngraph
/// Returns true if the type is floating point, else false.
bool
get_is_real
()
const
{
return
m_is_real
;
}
private
:
size_t
m_bitwidth
;
bool
m_is_real
;
bool
m_is_signed
;
std
::
string
m_cname
;
size_t
m_bitwidth
{
0
}
;
bool
m_is_real
{
false
}
;
bool
m_is_signed
{
false
}
;
std
::
string
m_cname
{
"unspecified"
}
;
};
template
<
typename
T
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment