Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
b839784e
Commit
b839784e
authored
Aug 31, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use value_type instead of type to be consistent with STL
Use direct implementation of is_parameter, is_op
parent
92c4d314
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
28 additions
and
36 deletions
+28
-36
element_type.hpp
src/ngraph/element_type.hpp
+1
-1
node.cpp
src/ngraph/node.cpp
+0
-10
node.hpp
src/ngraph/node.hpp
+2
-2
op.hpp
src/ngraph/op.hpp
+2
-1
constant.hpp
src/ngraph/ops/constant.hpp
+5
-5
parameter.hpp
src/ngraph/ops/parameter.hpp
+1
-0
type.hpp
src/ngraph/type.hpp
+8
-8
broadcast.cpp
src/ops/broadcast.cpp
+2
-2
dot.cpp
src/ops/dot.cpp
+3
-3
build_graph.cpp
test/build_graph.cpp
+4
-4
No files found.
src/ngraph/element_type.hpp
View file @
b839784e
...
...
@@ -66,7 +66,7 @@ namespace ngraph
public
:
// This is the C++ type used to hold a value of this element type during compilation
using
c
type
=
T
;
using
type
=
T
;
// This is a reference to an instance of this element type.
static
const
U
&
element_type
(){
static
U
t
;
...
...
src/ngraph/node.cpp
View file @
b839784e
...
...
@@ -29,16 +29,6 @@ ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type)
}
}
bool
ngraph
::
Node
::
is_op
()
const
{
return
dynamic_cast
<
const
ngraph
::
Op
*>
(
this
)
!=
nullptr
;
}
bool
ngraph
::
Node
::
is_parameter
()
const
{
return
dynamic_cast
<
const
ngraph
::
Parameter
*>
(
this
)
!=
nullptr
;
}
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
ngraph
::
Node
&
node
)
{
auto
op_tmp
=
dynamic_cast
<
const
ngraph
::
Op
*>
(
&
node
);
...
...
src/ngraph/node.hpp
View file @
b839784e
...
...
@@ -67,8 +67,8 @@ namespace ngraph
return
typeid
(
*
this
)
==
typeid
(
*
node
.
get
());
}
bool
is_op
()
const
;
bool
is_parameter
()
const
;
virtual
bool
is_op
()
const
{
return
false
;
}
;
virtual
bool
is_parameter
()
const
{
return
false
;
}
;
size_t
instance_id
()
const
{
return
m_instance_id
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
...
...
src/ngraph/op.hpp
View file @
b839784e
...
...
@@ -75,7 +75,8 @@ namespace ngraph
}
virtual
std
::
string
op_class_name
()
const
=
0
;
virtual
std
::
string
node_id
()
const
;
virtual
std
::
string
node_id
()
const
override
;
virtual
bool
is_op
()
const
override
{
return
true
;
}
};
/**
...
...
src/ngraph/ops/constant.hpp
View file @
b839784e
...
...
@@ -41,15 +41,15 @@ namespace ngraph
// The ngraph element type
using
element_type
=
T
;
// The C++ type that holds the element type
using
ctype
=
typename
T
::
c
type
;
using
type
=
typename
T
::
type
;
ScalarConstant
(
typename
T
::
c
type
value
)
ScalarConstant
(
typename
T
::
type
value
)
:
ScalarConstantBase
(
std
::
make_shared
<
TensorViewType
>
(
T
::
element_type
(),
Shape
{}))
,
m_value
(
value
)
{
}
virtual
std
::
string
description
()
const
override
{
return
"
ConstantScalar
"
;
}
virtual
std
::
string
description
()
const
override
{
return
"
ScalarConstant
"
;
}
virtual
std
::
string
node_id
()
const
override
{
std
::
stringstream
ss
;
...
...
@@ -57,7 +57,7 @@ namespace ngraph
return
ss
.
str
();
}
typename
T
::
c
type
value
()
const
{
return
m_value
;
}
typename
T
::
type
value
()
const
{
return
m_value
;
}
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
...
...
@@ -68,7 +68,7 @@ namespace ngraph
}
protected
:
typename
T
::
c
type
m_value
;
typename
T
::
type
m_value
;
};
using
FloatScalarConstant
=
ScalarConstant
<
element
::
Float
>
;
...
...
src/ngraph/ops/parameter.hpp
View file @
b839784e
...
...
@@ -41,6 +41,7 @@ namespace ngraph
std
::
string
description
()
const
override
{
return
"Parameter"
;
}
virtual
void
propagate_types
()
override
;
virtual
std
::
string
node_id
()
const
override
;
virtual
bool
is_parameter
()
const
override
{
return
true
;
};
protected
:
Function
*
m_function
;
...
...
src/ngraph/type.hpp
View file @
b839784e
...
...
@@ -112,8 +112,8 @@ namespace ngraph
class
TypedValueMixin
{
public
:
TypedValueMixin
(
const
ValueType
::
ptr
&
type
=
nullptr
)
:
m_
type
(
type
)
TypedValueMixin
(
const
ValueType
::
ptr
&
value_
type
=
nullptr
)
:
m_
value_type
(
value_
type
)
{
}
...
...
@@ -121,26 +121,26 @@ namespace ngraph
** Set the type
** /param type The new type
**/
void
type
(
const
ValueType
::
ptr
&
type
)
{
m_type
=
type
;
}
void
value_type
(
const
ValueType
::
ptr
&
value_type
)
{
m_value_type
=
value_
type
;
}
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void
type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
void
value_
type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
{
m_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
m_
value_
type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
}
/**
** The type associated with this value.
**/
ValueType
::
ptr
type
()
{
return
m
_type
;
}
ValueType
::
ptr
value_type
()
{
return
m_value
_type
;
}
/**
** The type associated with this value.
**/
const
ValueType
::
ptr
type
()
const
{
return
m
_type
;
}
const
ValueType
::
ptr
value_type
()
const
{
return
m_value
_type
;
}
protected
:
ValueType
::
ptr
m_type
;
ValueType
::
ptr
m_
value_
type
;
};
}
src/ops/broadcast.cpp
View file @
b839784e
...
...
@@ -32,7 +32,7 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
void
BroadcastOp
::
propagate_types
()
{
auto
arg_type
=
m_arguments
.
at
(
0
)
->
type
();
auto
arg_type
=
m_arguments
.
at
(
0
)
->
value_
type
();
if
(
nullptr
==
arg_type
)
{
throw
ngraph_error
(
"Argument to broadcast is missing type."
);
...
...
@@ -53,5 +53,5 @@ void BroadcastOp::propagate_types()
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type
=
make_shared
<
TensorViewType
>
(
arg_tensor_view_type
->
element_type
(),
m_shape
);
m_
value_
type
=
make_shared
<
TensorViewType
>
(
arg_tensor_view_type
->
element_type
(),
m_shape
);
}
src/ops/dot.cpp
View file @
b839784e
...
...
@@ -27,8 +27,8 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
void
DotOp
::
propagate_types
()
{
auto
arg0_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
type
());
auto
arg1_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
1
)
->
type
());
auto
arg0_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
value_
type
());
auto
arg1_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
1
)
->
value_
type
());
if
(
nullptr
==
arg0_tensor_type
||
nullptr
==
arg1_tensor_type
)
{
throw
ngraph_error
(
"Arguments to dot must be tensor views"
);
...
...
@@ -60,5 +60,5 @@ void DotOp::propagate_types()
copy
(
arg0_shape
.
begin
(),
arg0_shape
.
begin
()
+
arg1_reduction
,
result_shape
.
end
());
copy
(
arg1_shape
.
begin
(),
arg1_shape
.
begin
()
+
arg1_reduction
,
result_shape
.
end
());
copy
(
arg1_shape
.
begin
()
+
arg1_reduction
,
arg1_shape
.
end
(),
result_shape
.
end
());
m_type
=
make_shared
<
TensorViewType
>
(
arg0_tensor_type
->
element_type
(),
result_shape
);
m_
value_
type
=
make_shared
<
TensorViewType
>
(
arg0_tensor_type
->
element_type
(),
result_shape
);
}
test/build_graph.cpp
View file @
b839784e
...
...
@@ -93,7 +93,7 @@ TEST(build_graph, literal)
auto
float0
=
FloatScalarConstant
::
make
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
(),
Shape
{});
ASSERT_EQ
(
float0
->
value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float0
->
value_
type
(),
float_scalar_type
);
auto
d
=
op
::
dot
(
float0
,
float0
);
ASSERT_EQ
(
d
->
arguments
().
at
(
0
),
float0
);
ASSERT_EQ
(
d
->
arguments
().
at
(
1
),
float0
);
...
...
@@ -101,13 +101,13 @@ TEST(build_graph, literal)
// float scalar from an int
auto
float1
=
FloatScalarConstant
::
make
(
3
);
ASSERT_EQ
(
float1
->
value
(),
3
);
ASSERT_EQ
(
*
float1
->
type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float1
->
value_
type
(),
float_scalar_type
);
auto
int32_0
=
Int32ScalarConstant
::
make
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{});
ASSERT_EQ
(
int32_0
->
value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
type
(),
int32_scalar_type
);
ASSERT_NE
(
*
int32_0
->
type
(),
float_scalar_type
);
ASSERT_EQ
(
*
int32_0
->
value_
type
(),
int32_scalar_type
);
ASSERT_NE
(
*
int32_0
->
value_
type
(),
float_scalar_type
);
}
// Check argument inverses
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment