Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
273599ca
Commit
273599ca
authored
Sep 15, 2017
by
Adam Procter
Committed by
GitHub
Sep 15, 2017
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #125 from NervanaSystems/bob/remove_op
remove class Op and merge functionality with Node
parents
6a0ac42e
1cf260ec
Show whitespace changes
Inline
Side-by-side
Showing
29 changed files
with
40 additions
and
81 deletions
+40
-81
node.cpp
src/ngraph/node.cpp
+8
-11
node.hpp
src/ngraph/node.hpp
+1
-4
op.hpp
src/ngraph/op.hpp
+6
-26
abs.hpp
src/ngraph/ops/abs.hpp
+1
-1
add.hpp
src/ngraph/ops/add.hpp
+1
-1
broadcast.hpp
src/ngraph/ops/broadcast.hpp
+1
-1
ceiling.hpp
src/ngraph/ops/ceiling.hpp
+1
-1
concatenate.hpp
src/ngraph/ops/concatenate.hpp
+1
-1
convert.hpp
src/ngraph/ops/convert.hpp
+1
-1
divide.hpp
src/ngraph/ops/divide.hpp
+1
-1
dot.hpp
src/ngraph/ops/dot.hpp
+1
-1
equal.hpp
src/ngraph/ops/equal.hpp
+1
-1
exp.hpp
src/ngraph/ops/exp.hpp
+1
-1
floor.hpp
src/ngraph/ops/floor.hpp
+1
-1
greater.hpp
src/ngraph/ops/greater.hpp
+1
-1
less.hpp
src/ngraph/ops/less.hpp
+1
-1
log.hpp
src/ngraph/ops/log.hpp
+1
-1
maximum.hpp
src/ngraph/ops/maximum.hpp
+1
-1
minimum.hpp
src/ngraph/ops/minimum.hpp
+1
-1
multiply.hpp
src/ngraph/ops/multiply.hpp
+1
-1
negative.hpp
src/ngraph/ops/negative.hpp
+1
-1
power.hpp
src/ngraph/ops/power.hpp
+1
-1
remainder.hpp
src/ngraph/ops/remainder.hpp
+1
-1
subtract.hpp
src/ngraph/ops/subtract.hpp
+1
-1
tuple.hpp
src/ngraph/ops/tuple.hpp
+1
-1
visualize.cpp
src/ngraph/visualize.cpp
+1
-5
op.cpp
src/ops/op.cpp
+0
-7
op.cpp
test/op.cpp
+0
-2
topological_sort.cpp
test/topological_sort.cpp
+2
-4
No files found.
src/ngraph/node.cpp
View file @
273599ca
...
...
@@ -74,27 +74,24 @@ void Node::assign_tensors()
}
}
bool
Node
::
is_
op
()
const
bool
Node
::
is_
parameter
()
const
{
return
dynamic_cast
<
const
Op
*>
(
this
)
!=
nullptr
;
return
dynamic_cast
<
const
op
::
Parameter
*>
(
this
)
!=
nullptr
;
}
bool
Node
::
is_parameter
()
const
std
::
string
Node
::
get_node_id
()
const
{
return
dynamic_cast
<
const
op
::
Parameter
*>
(
this
)
!=
nullptr
;
stringstream
ss
;
ss
<<
description
()
<<
"_"
<<
m_instance_id
;
return
ss
.
str
();
}
namespace
ngraph
{
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
{
auto
op_tmp
=
dynamic_cast
<
const
Op
*>
(
&
node
);
auto
parameter_tmp
=
dynamic_cast
<
const
Op
*>
(
&
node
);
if
(
op_tmp
)
{
out
<<
"Op("
<<
op_tmp
->
get_node_id
()
<<
")"
;
}
else
if
(
parameter_tmp
)
auto
parameter_tmp
=
dynamic_cast
<
const
op
::
Parameter
*>
(
&
node
);
if
(
parameter_tmp
)
{
out
<<
"Parameter("
<<
parameter_tmp
->
get_node_id
()
<<
")"
;
}
...
...
src/ngraph/node.hpp
View file @
273599ca
...
...
@@ -25,8 +25,6 @@
namespace
ngraph
{
class
Op
;
namespace
descriptor
{
class
Input
;
...
...
@@ -71,7 +69,7 @@ namespace ngraph
std
::
string
get_name
()
const
{
return
m_name
;
}
void
set_name
(
const
std
::
string
&
name
)
{
m_name
=
name
;
}
virtual
std
::
string
get_node_id
()
const
=
0
;
virtual
std
::
string
get_node_id
()
const
;
/// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern
...
...
@@ -100,7 +98,6 @@ namespace ngraph
// independently compute what we thing the value type should be from the arguments.
void
set_value_type_checked
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
);
bool
is_op
()
const
;
bool
is_parameter
()
const
;
size_t
get_instance_id
()
const
{
return
m_instance_id
;
}
...
...
src/ngraph/op.hpp
View file @
273599ca
...
...
@@ -22,32 +22,12 @@
namespace
ngraph
{
/// Op nodes are nodes whose value is the result of some operation
/// applied to its arguments. For calls to user functions, the op will
/// reference the user function.
class
Op
:
public
Node
{
public
:
Op
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
)
:
Node
(
arguments
)
{
}
Op
()
:
Node
()
{
}
virtual
std
::
string
get_op_class_name
()
const
=
0
;
virtual
std
::
string
get_node_id
()
const
override
;
};
// TODO: These class definitions are to be moved into separate files in the op directory
namespace
op
{
/// A Function invokes a function on node arguments. In addition to the argument
/// we need to preserve the function.
class
FunctionCall
:
public
Op
class
FunctionCall
:
public
Node
{
virtual
std
::
string
description
()
const
override
{
return
"FunctionCall"
;
}
...
...
@@ -57,14 +37,14 @@ namespace ngraph
/// The is an operation we handle directly, i.e. all type checking, etc.
/// are defined in C++ rather than in terms of ngraph operations.
class
Builtin
:
public
Op
class
Builtin
:
public
Node
{
public
:
virtual
std
::
string
description
()
const
override
{
return
"Builtin"
;
}
protected
:
Builtin
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
args
)
:
Op
(
args
)
:
Node
(
args
)
{
}
};
...
...
@@ -88,7 +68,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Reshape"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Reshape"
;
}
//virtual void propagate_types() override;
protected
:
Shape
m_shape
;
...
...
@@ -147,7 +127,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"BinaryElementwiseComparison"
;
}
virtual
std
::
string
description
()
const
override
{
return
"BinaryElementwiseComparison"
;
}
//virtual void propagate_types() override;
virtual
const
element
::
Type
&
propagate_element_types
(
const
element
::
Type
&
arg0_element_type
,
...
...
@@ -163,7 +143,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"BinaryElementwiseArithmetic"
;
}
virtual
std
::
string
description
()
const
override
{
return
"BinaryElementwiseArithmetic"
;
}
//virtual void propagate_types() override;
virtual
const
element
::
Type
&
propagate_element_types
(
const
element
::
Type
&
arg0_element_type
,
...
...
src/ngraph/ops/abs.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Abs"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Abs"
;
}
};
}
}
src/ngraph/ops/add.hpp
View file @
273599ca
...
...
@@ -25,7 +25,7 @@ namespace ngraph
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Add"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Add"
;
}
};
}
}
src/ngraph/ops/broadcast.hpp
View file @
273599ca
...
...
@@ -36,7 +36,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Broadcast"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Broadcast"
;
}
virtual
void
propagate_types
()
override
;
protected
:
...
...
src/ngraph/ops/ceiling.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Ceiling"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Ceiling"
;
}
};
}
}
src/ngraph/ops/concatenate.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Concatenate"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Concatenate"
;
}
virtual
void
propagate_types
()
override
;
};
}
...
...
src/ngraph/ops/convert.hpp
View file @
273599ca
...
...
@@ -27,7 +27,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Convert"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Convert"
;
}
virtual
void
propagate_types
()
override
;
protected
:
...
...
src/ngraph/ops/divide.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Divide"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Divide"
;
}
};
}
}
src/ngraph/ops/dot.hpp
View file @
273599ca
...
...
@@ -45,7 +45,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Dot"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Dot"
;
}
virtual
void
propagate_types
()
override
;
};
}
...
...
src/ngraph/ops/equal.hpp
View file @
273599ca
...
...
@@ -25,7 +25,7 @@ namespace ngraph
:
BinaryElementwiseComparison
(
arg0
,
arg1
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Equal"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Equal"
;
}
};
}
}
src/ngraph/ops/exp.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Exp"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Exp"
;
}
};
}
}
src/ngraph/ops/floor.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Floor"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Floor"
;
}
};
}
}
src/ngraph/ops/greater.hpp
View file @
273599ca
...
...
@@ -25,7 +25,7 @@ namespace ngraph
:
BinaryElementwiseComparison
(
arg0
,
arg1
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Greater"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Greater"
;
}
};
}
}
src/ngraph/ops/less.hpp
View file @
273599ca
...
...
@@ -25,7 +25,7 @@ namespace ngraph
:
BinaryElementwiseComparison
(
arg0
,
arg1
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Less"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Less"
;
}
};
}
}
src/ngraph/ops/log.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Log"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Log"
;
}
};
}
}
src/ngraph/ops/maximum.hpp
View file @
273599ca
...
...
@@ -25,7 +25,7 @@ namespace ngraph
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Maximum"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Maximum"
;
}
};
}
}
src/ngraph/ops/minimum.hpp
View file @
273599ca
...
...
@@ -25,7 +25,7 @@ namespace ngraph
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Minimum"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Minimum"
;
}
};
}
}
src/ngraph/ops/multiply.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Multiply"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Multiply"
;
}
};
}
}
src/ngraph/ops/negative.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Negative"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Negative"
;
}
};
}
}
src/ngraph/ops/power.hpp
View file @
273599ca
...
...
@@ -25,7 +25,7 @@ namespace ngraph
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Power"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Power"
;
}
};
}
}
src/ngraph/ops/remainder.hpp
View file @
273599ca
...
...
@@ -25,7 +25,7 @@ namespace ngraph
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Remainder"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Remainder"
;
}
};
}
}
src/ngraph/ops/subtract.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Subtract"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Subtract"
;
}
};
}
}
src/ngraph/ops/tuple.hpp
View file @
273599ca
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Tuple"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Tuple"
;
}
virtual
void
propagate_types
()
override
;
};
}
...
...
src/ngraph/visualize.cpp
View file @
273599ca
...
...
@@ -60,13 +60,9 @@ std::string Visualize::get_attributes(const Node* node)
{
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=box color=blue]
\n
"
;
}
else
if
(
node
->
is_op
())
{
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=ellipse color=black]
\n
"
;
}
else
{
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=
diamond color=red
]
\n
"
;
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=
ellipse color=black
]
\n
"
;
}
return
ss
.
str
();
}
...
...
src/ops/op.cpp
View file @
273599ca
...
...
@@ -19,10 +19,3 @@
using
namespace
ngraph
;
using
namespace
std
;
std
::
string
ngraph
::
Op
::
get_node_id
()
const
{
stringstream
ss
;
ss
<<
get_op_class_name
()
<<
"_"
<<
m_instance_id
;
return
ss
.
str
();
}
test/op.cpp
View file @
273599ca
...
...
@@ -26,7 +26,6 @@ TEST(op, is_op)
auto
arg0
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
1
});
ASSERT_NE
(
nullptr
,
arg0
);
EXPECT_TRUE
(
arg0
->
is_parameter
());
EXPECT_FALSE
(
arg0
->
is_op
());
}
TEST
(
op
,
is_parameter
)
...
...
@@ -36,5 +35,4 @@ TEST(op, is_parameter)
auto
t0
=
make_shared
<
op
::
Add
>
(
arg0
,
arg0
);
ASSERT_NE
(
nullptr
,
t0
);
EXPECT_FALSE
(
t0
->
is_parameter
());
EXPECT_TRUE
(
t0
->
is_op
());
}
test/topological_sort.cpp
View file @
273599ca
...
...
@@ -107,17 +107,15 @@ TEST(benchmark, topological_sort)
result
=
make_cell
(
result
,
in_1
,
in_2
);
}
auto
op_r0
=
static_pointer_cast
<
Op
>
(
result
);
timer
.
start
();
pass
::
TopologicalSort
ts
;
ts
.
run_on_tree
(
op_r0
);
ts
.
run_on_tree
(
result
);
auto
sorted_list
=
ts
.
get_call_graph
();
timer
.
stop
();
INFO
<<
"topological sort took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
size_t
node_count
=
0
;
traverse_nodes
(
op_r0
,
[
&
](
const
Node
*
node
)
{
traverse_nodes
(
result
,
[
&
](
const
Node
*
node
)
{
node_count
++
;
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment