Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
36e36e7f
Commit
36e36e7f
authored
Sep 02, 2017
by
Robert Kimball
Committed by
GitHub
Sep 02, 2017
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #78 from NervanaSystems/cyphers/morenames
De-use and cleanup op names.
parents
973b3a0e
c7ef13f5
Hide whitespace changes
Inline
Side-by-side
Showing
23 changed files
with
389 additions
and
626 deletions
+389
-626
common.hpp
src/ngraph/common.hpp
+6
-10
function.hpp
src/ngraph/function.hpp
+5
-15
node.cpp
src/ngraph/node.cpp
+1
-1
op.hpp
src/ngraph/op.hpp
+195
-242
broadcast.hpp
src/ngraph/ops/broadcast.hpp
+23
-27
concatenate.hpp
src/ngraph/ops/concatenate.hpp
+10
-12
constant.hpp
src/ngraph/ops/constant.hpp
+44
-41
convert.hpp
src/ngraph/ops/convert.hpp
+13
-16
dot.hpp
src/ngraph/ops/dot.hpp
+11
-14
parameter.hpp
src/ngraph/ops/parameter.hpp
+26
-34
tuple.hpp
src/ngraph/ops/tuple.hpp
+10
-12
broadcast.cpp
src/ops/broadcast.cpp
+2
-13
concatenate.cpp
src/ops/concatenate.cpp
+2
-7
constant.cpp
src/ops/constant.cpp
+1
-1
convert.cpp
src/ops/convert.cpp
+2
-8
dot.cpp
src/ops/dot.cpp
+2
-9
function.cpp
src/ops/function.cpp
+1
-13
op.cpp
src/ops/op.cpp
+0
-100
parameter.cpp
src/ops/parameter.cpp
+2
-13
tuple.cpp
src/ops/tuple.cpp
+2
-7
build_graph.cpp
test/build_graph.cpp
+19
-19
op.cpp
test/op.cpp
+3
-3
topological_sort.cpp
test/topological_sort.cpp
+9
-9
No files found.
src/ngraph/common.hpp
View file @
36e36e7f
...
@@ -22,14 +22,13 @@
...
@@ -22,14 +22,13 @@
namespace
ngraph
namespace
ngraph
{
{
class
Node
;
class
Node
;
class
Parameter
;
namespace
op
{
class
ValueType
;
class
Parameter
;
template
<
typename
T
,
typename
...
A
>
/// A list of parameters
std
::
shared_ptr
<
T
>
node
(
A
&&
...
args
)
using
Parameters
=
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>
;
{
return
std
::
make_shared
<
T
>
(
args
...);
}
}
class
ValueType
;
/// Zero or more value types
/// Zero or more value types
using
ValueTypes
=
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
;
using
ValueTypes
=
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
;
...
@@ -42,7 +41,4 @@ namespace ngraph
...
@@ -42,7 +41,4 @@ namespace ngraph
/// A set of axes, for example, reduction axes
/// A set of axes, for example, reduction axes
using
AxisSet
=
std
::
set
<
size_t
>
;
using
AxisSet
=
std
::
set
<
size_t
>
;
/// A list of parameters
using
Parameters
=
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>
;
}
}
src/ngraph/function.hpp
View file @
36e36e7f
...
@@ -26,28 +26,18 @@ namespace ngraph
...
@@ -26,28 +26,18 @@ namespace ngraph
{
{
public
:
public
:
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
);
std
::
shared_ptr
<
Node
>
get_result
()
{
return
m_result
;
}
std
::
shared_ptr
<
Node
>
get_result
()
{
return
m_result
;
}
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>
get_parameters
()
const
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
get_parameters
()
const
{
{
return
m_parameters
;
return
m_parameters
;
}
}
std
::
string
get_name
()
const
{
return
m_name
;
}
std
::
string
get_name
()
const
{
return
m_name
;
}
protected
:
protected
:
std
::
shared_ptr
<
Node
>
m_result
;
std
::
shared_ptr
<
Node
>
m_result
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>
m_parameters
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>>
m_parameters
;
std
::
string
m_name
;
std
::
string
m_name
;
};
};
namespace
op
{
std
::
shared_ptr
<
Function
>
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
initializer_list
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
std
::
shared_ptr
<
Function
>
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
}
}
}
src/ngraph/node.cpp
View file @
36e36e7f
...
@@ -37,7 +37,7 @@ bool ngraph::Node::is_op() const
...
@@ -37,7 +37,7 @@ bool ngraph::Node::is_op() const
bool
ngraph
::
Node
::
is_parameter
()
const
bool
ngraph
::
Node
::
is_parameter
()
const
{
{
return
dynamic_cast
<
const
ngraph
::
Parameter
*>
(
this
)
!=
nullptr
;
return
dynamic_cast
<
const
ngraph
::
op
::
Parameter
*>
(
this
)
!=
nullptr
;
}
}
namespace
ngraph
namespace
ngraph
...
...
src/ngraph/op.hpp
View file @
36e36e7f
...
@@ -22,57 +22,6 @@
...
@@ -22,57 +22,6 @@
namespace
ngraph
namespace
ngraph
{
{
namespace
op
{
std
::
shared_ptr
<
Node
>
abs
(
const
std
::
shared_ptr
<
Node
>&
arg
);
std
::
shared_ptr
<
Node
>
add
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
ceiling
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> convert();
//std::shared_ptr<Node> convolution();
std
::
shared_ptr
<
Node
>
divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
equal
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
exp
(
const
std
::
shared_ptr
<
Node
>&
arg0
);
std
::
shared_ptr
<
Node
>
floor
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> get_tuple_element();
std
::
shared_ptr
<
Node
>
greater
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std
::
shared_ptr
<
Node
>
less
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std
::
shared_ptr
<
Node
>
log
(
const
std
::
shared_ptr
<
Node
>&
arg0
);
//std::shared_ptr<Node> logical(); and, or, not
std
::
shared_ptr
<
Node
>
maximum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
minimum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
multiply
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
negative
(
const
std
::
shared_ptr
<
Node
>&
arg0
);
//std::shared_ptr<Node> pad();
std
::
shared_ptr
<
Node
>
power
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> reduce();
// std::shared_ptr<Node> reduce_window();
std
::
shared_ptr
<
Node
>
remainder
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
reshape
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
Shape
&
shape
);
//std::shared_ptr<Node> reverse();
//std::shared_ptr<Node> rng();
//std::shared_ptr<Node> select();
//std::shared_ptr<Node> select_scatter();
//std::shared_ptr<Node> slice();
std
::
shared_ptr
<
Node
>
subtract
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> transpose();
//std::shared_ptr<Node> while();
}
/// Op nodes are nodes whose value is the result of some operation
/// Op nodes are nodes whose value is the result of some operation
/// applied to its arguments. For calls to user functions, the op will
/// applied to its arguments. For calls to user functions, the op will
/// reference the user function.
/// reference the user function.
...
@@ -93,249 +42,253 @@ namespace ngraph
...
@@ -93,249 +42,253 @@ namespace ngraph
virtual
std
::
string
get_node_id
()
const
override
;
virtual
std
::
string
get_node_id
()
const
override
;
};
};
/// A FunctionOp invokes a function on node arguments. In addition to the argument
// TODO: These class definitions are to be moved into separate files in the op directory
/// we need to preserve the function.
namespace
op
class
FunctionOp
:
public
Op
{
{
virtual
std
::
string
description
()
const
override
{
return
"FunctionOp"
;
}
/// A Function invokes a function on node arguments. In addition to the argument
/// we need to preserve the function.
class
FunctionCall
:
public
Op
{
virtual
std
::
string
description
()
const
override
{
return
"FunctionCall"
;
}
protected
:
protected
:
std
::
shared_ptr
<
Node
>
m_function
;
std
::
shared_ptr
<
Node
>
m_function
;
};
};
/// The is an operation we handle directly, i.e. all type checking, etc.
/// The is an operation we handle directly, i.e. all type checking, etc.
/// are defined in C++ rather than in terms of ngraph operations.
/// are defined in C++ rather than in terms of ngraph operations.
class
BuiltinOp
:
public
Op
class
Builtin
:
public
Op
{
{
public
:
public
:
virtual
std
::
string
description
()
const
override
{
return
"BuiltinOp
"
;
}
virtual
std
::
string
description
()
const
override
{
return
"Builtin
"
;
}
/// Name of the builtin op, for debugging and logging.
/// Name of the builtin op, for debugging and logging.
// TODO: Implement for each op. This enables graphs to be built for now.
// TODO: Implement for each op. This enables graphs to be built for now.
virtual
void
propagate_types
()
override
{}
virtual
void
propagate_types
()
override
{}
protected
:
protected
:
BuiltinOp
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
args
)
Builtin
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
args
)
:
Op
(
args
)
:
Op
(
args
)
{
{
}
}
};
};
class
AbsOp
:
public
BuiltinOp
class
Abs
:
public
Builtin
{
public
:
AbsOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
:
BuiltinOp
({
arg0
})
{
{
}
public
:
Abs
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
:
Builtin
({
arg0
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"a
bs"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"A
bs"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
AddOp
:
public
BuiltinOp
class
Add
:
public
Builtin
{
public
:
AddOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"add"
;
}
Add
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
//virtual void propagate_types() override;
:
Builtin
({
arg0
,
arg1
})
};
{
}
class
CeilingOp
:
public
BuiltinOp
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Add"
;
}
{
//virtual void propagate_types() override;
public
:
};
CeilingOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
class
Ceiling
:
public
Builtin
{
{
}
public
:
Ceiling
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"c
eiling"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"C
eiling"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
DivideOp
:
public
BuiltinOp
class
Divide
:
public
Builtin
{
public
:
DivideOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"d
ivide"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"D
ivide"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
EqualOp
:
public
BuiltinOp
class
Equal
:
public
Builtin
{
public
:
EqualOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Equal
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"e
qual"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"E
qual"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
ExpOp
:
public
BuiltinOp
class
Exp
:
public
Builtin
{
public
:
ExpOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
:
BuiltinOp
({
arg0
})
{
{
}
public
:
Exp
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
:
Builtin
({
arg0
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"e
xp"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"E
xp"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
FloorOp
:
public
BuiltinOp
class
Floor
:
public
Builtin
{
public
:
FloorOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Floor
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"f
loor"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"F
loor"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
GreaterOp
:
public
BuiltinOp
class
Greater
:
public
Builtin
{
public
:
GreaterOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Greater
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"g
reater"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"G
reater"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
LessOp
:
public
BuiltinOp
class
Less
:
public
Builtin
{
public
:
LessOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Less
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"l
ess"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"L
ess"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
LogOp
:
public
BuiltinOp
class
Log
:
public
Builtin
{
public
:
LogOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
:
BuiltinOp
({
arg0
})
{
{
}
public
:
Log
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
:
Builtin
({
arg0
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"l
og"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"L
og"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
MaximumOp
:
public
BuiltinOp
class
Maximum
:
public
Builtin
{
public
:
MaximumOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Maximum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"m
ax"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"M
ax"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
MinimumOp
:
public
BuiltinOp
class
Minimum
:
public
Builtin
{
public
:
MinimumOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Minimum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"m
in"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"M
in"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
MultiplyOp
:
public
BuiltinOp
class
Multiply
:
public
Builtin
{
public
:
MultiplyOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Multiply
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"m
ultiply"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"M
ultiply"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
NegativeOp
:
public
BuiltinOp
class
Negative
:
public
Builtin
{
public
:
NegativeOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
:
BuiltinOp
({
arg0
})
{
{
}
public
:
Negative
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
:
Builtin
({
arg0
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"n
egative"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"N
egative"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
PowerOp
:
public
BuiltinOp
class
Power
:
public
Builtin
{
public
:
PowerOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Power
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"p
ower"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"P
ower"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
RemainderOp
:
public
BuiltinOp
class
Remainder
:
public
Builtin
{
public
:
RemainderOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
Remainder
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Builtin
({
arg0
,
arg1
})
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"r
emainder"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"R
emainder"
;
}
//virtual void propagate_types() override;
//virtual void propagate_types() override;
};
};
class
ReshapeOp
:
public
BuiltinOp
class
Reshape
:
public
Builtin
{
public
:
ReshapeOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
Shape
&
shape
)
:
BuiltinOp
({
arg0
})
,
m_shape
(
shape
)
{
{
}
public
:
Reshape
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
Shape
&
shape
)
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"reshape"
;
}
:
Builtin
({
arg0
})
//virtual void propagate_types() override;
,
m_shape
(
shape
)
protected
:
{
Shape
m_shape
;
}
};
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Reshape"
;
}
class
SubtractOp
:
public
BuiltinOp
//virtual void propagate_types() override;
{
protected
:
public
:
Shape
m_shape
;
SubtractOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
};
:
BuiltinOp
({
arg0
,
arg1
})
class
Subtract
:
public
Builtin
{
{
}
public
:
Subtract
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"subtract"
;
}
:
Builtin
({
arg0
,
arg1
})
//virtual void propagate_types() override;
{
};
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Subtract"
;
}
//virtual void propagate_types() override;
};
}
}
}
src/ngraph/ops/broadcast.hpp
View file @
36e36e7f
...
@@ -16,36 +16,32 @@
...
@@ -16,36 +16,32 @@
namespace
ngraph
namespace
ngraph
{
{
class
BroadcastOp
:
public
BuiltinO
p
namespace
o
p
{
{
public
:
class
Broadcast
:
public
Builtin
///
/// @param arg The tensor view to be broadcast.
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
///
BroadcastOp
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
)
:
BuiltinOp
({
arg
})
,
m_shape
(
shape
)
,
m_broadcast_axes
(
broadcast_axes
)
{
{
}
public
:
///
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"broadcast"
;
}
/// @param arg The tensor view to be broadcast.
virtual
void
propagate_types
()
override
;
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
///
Broadcast
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
)
:
Builtin
({
arg
})
,
m_shape
(
shape
)
,
m_broadcast_axes
(
broadcast_axes
)
{
}
protected
:
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Broadcast"
;
}
Shape
m_shape
;
virtual
void
propagate_types
()
override
;
AxisSet
m_broadcast_axes
;
};
namespace
op
protected
:
{
Shape
m_shape
;
std
::
shared_ptr
<
Node
>
broadcast
(
const
std
::
shared_ptr
<
Node
>&
tensor
,
AxisSet
m_broadcast_axes
;
const
Shape
&
shape
,
};
AxisSet
&&
broadcast_axes
);
}
}
}
}
src/ngraph/ops/concatenate.hpp
View file @
36e36e7f
...
@@ -18,18 +18,16 @@ namespace ngraph
...
@@ -18,18 +18,16 @@ namespace ngraph
{
{
namespace
op
namespace
op
{
{
std
::
shared_ptr
<
Node
>
concatenate
(
const
Nodes
&
args
);
class
Concat
:
public
Builtin
}
class
ConcatOp
:
public
BuiltinOp
{
public
:
ConcatOp
(
const
Nodes
&
args
)
:
BuiltinOp
(
args
)
{
{
}
public
:
Concat
(
const
Nodes
&
args
)
:
Builtin
(
args
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"concatenate"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Concatenate"
;
}
virtual
void
propagate_types
()
override
;
virtual
void
propagate_types
()
override
;
};
};
}
}
}
src/ngraph/ops/constant.hpp
View file @
36e36e7f
...
@@ -20,54 +20,57 @@
...
@@ -20,54 +20,57 @@
namespace
ngraph
namespace
ngraph
{
{
// Defines methods to all constant scalars
namespace
op
class
ScalarConstantBase
:
public
Node
{
{
protected
:
// Defines methods to all constant scalars
ScalarConstantBase
(
const
std
::
shared_ptr
<
TensorViewType
>&
type
)
class
ScalarConstantBase
:
public
Node
:
Node
({},
type
)
{
{
}
protected
:
ScalarConstantBase
(
const
std
::
shared_ptr
<
TensorViewType
>&
type
)
:
Node
({},
type
)
{
}
virtual
void
propagate_types
()
override
;
virtual
void
propagate_types
()
override
;
};
};
// Implement a constant scalar for each element type.
// Implement a constant scalar for each element type.
// The static make method takes a
// The static make method takes a
template
<
typename
T
>
template
<
typename
T
>
class
ScalarConstant
:
public
ScalarConstantBase
class
ScalarConstant
:
public
ScalarConstantBase
{
public
:
// The ngraph element type
using
element_type
=
T
;
// The C++ type that holds the element type
using
type
=
typename
T
::
type
;
ScalarConstant
(
typename
T
::
type
value
)
:
ScalarConstantBase
(
std
::
make_shared
<
TensorViewType
>
(
T
::
element_type
(),
Shape
{}))
,
m_value
(
value
)
{
{
}
public
:
// The ngraph element type
using
element_type
=
T
;
// The C++ type that holds the element type
using
type
=
typename
T
::
type
;
virtual
std
::
string
description
()
const
override
{
return
"ScalarConstant"
;
}
ScalarConstant
(
typename
T
::
type
value
)
virtual
std
::
string
get_node_id
()
const
override
:
ScalarConstantBase
(
std
::
make_shared
<
TensorViewType
>
(
T
::
element_type
(),
Shape
{}))
{
,
m_value
(
value
)
std
::
stringstream
ss
;
{
ss
<<
description
()
<<
"_"
/* << node_id() */
;
}
return
ss
.
str
();
}
virtual
std
::
string
description
()
const
override
{
return
"ScalarConstant"
;
}
virtual
std
::
string
get_node_id
()
const
override
{
std
::
stringstream
ss
;
ss
<<
description
()
<<
"_"
/* << node_id() */
;
return
ss
.
str
();
}
typename
T
::
type
get_value
()
const
{
return
m_value
;
}
typename
T
::
type
get_value
()
const
{
return
m_value
;
}
protected
:
protected
:
typename
T
::
type
m_value
;
typename
T
::
type
m_value
;
};
};
using
Float32ScalarConstant
=
ScalarConstant
<
element
::
Float32
>
;
using
Float32ScalarConstant
=
ScalarConstant
<
element
::
Float32
>
;
using
Int8ScalarConstant
=
ScalarConstant
<
element
::
Int8
>
;
using
Int8ScalarConstant
=
ScalarConstant
<
element
::
Int8
>
;
using
Int32ScalarConstant
=
ScalarConstant
<
element
::
Int32
>
;
using
Int32ScalarConstant
=
ScalarConstant
<
element
::
Int32
>
;
using
Int64ScalarConstant
=
ScalarConstant
<
element
::
Int64
>
;
using
Int64ScalarConstant
=
ScalarConstant
<
element
::
Int64
>
;
using
UInt8ScalarConstant
=
ScalarConstant
<
element
::
UInt8
>
;
using
UInt8ScalarConstant
=
ScalarConstant
<
element
::
UInt8
>
;
using
UInt32ScalarConstant
=
ScalarConstant
<
element
::
UInt32
>
;
using
UInt32ScalarConstant
=
ScalarConstant
<
element
::
UInt32
>
;
using
UInt64ScalarConstant
=
ScalarConstant
<
element
::
UInt64
>
;
using
UInt64ScalarConstant
=
ScalarConstant
<
element
::
UInt64
>
;
}
}
}
src/ngraph/ops/convert.hpp
View file @
36e36e7f
...
@@ -16,25 +16,22 @@
...
@@ -16,25 +16,22 @@
namespace
ngraph
namespace
ngraph
{
{
class
ConvertOp
:
public
BuiltinO
p
namespace
o
p
{
{
public
:
class
Convert
:
public
Builtin
ConvertOp
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
ngraph
::
element
::
Type
&
element_type
)
:
BuiltinOp
({
arg
})
,
m_element_type
(
element_type
)
{
{
}
public
:
Convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
ngraph
::
element
::
Type
&
element_type
)
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"convert"
;
}
:
Builtin
({
arg
})
virtual
void
propagate_types
()
override
;
,
m_element_type
(
element_type
)
{
}
protected
:
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Convert"
;
}
const
ngraph
::
element
::
Type
&
m_element_type
;
virtual
void
propagate_types
()
override
;
};
namespace
op
protected
:
{
const
ngraph
::
element
::
Type
&
m_element_type
;
std
::
shared_ptr
<
ngraph
::
ConvertOp
>
convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
};
const
ngraph
::
element
::
Type
&
element_type
);
}
}
}
}
src/ngraph/ops/dot.hpp
View file @
36e36e7f
...
@@ -16,22 +16,19 @@
...
@@ -16,22 +16,19 @@
namespace
ngraph
namespace
ngraph
{
{
class
DotOp
:
public
BuiltinO
p
namespace
o
p
{
{
public
:
class
Dot
:
public
Builtin
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
{
}
public
:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"dot"
;
}
Dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
virtual
void
propagate_types
()
override
;
:
Builtin
({
arg0
,
arg1
})
};
{
}
namespace
op
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Dot"
;
}
{
virtual
void
propagate_types
()
override
;
std
::
shared_ptr
<
Node
>
dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
};
const
std
::
shared_ptr
<
Node
>&
arg1
);
}
}
}
}
src/ngraph/ops/parameter.hpp
View file @
36e36e7f
...
@@ -20,41 +20,33 @@
...
@@ -20,41 +20,33 @@
namespace
ngraph
namespace
ngraph
{
{
class
Function
;
class
Function
;
///
/// Parameters are nodes that represent the arguments that will be passed to user-defined functions.
/// Function creation requires a sequence of parameters.
/// Basic graph operations do not need parameters attached to a function.
///
class
Parameter
:
public
Node
{
friend
class
Function
;
protected
:
// Called by the Function constructor to associate this parameter with the function.
// It is an error to try to associate a parameter with more than one function.
void
assign_function
(
Function
*
function
,
size_t
index
);
public
:
Parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
);
Parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
);
std
::
string
description
()
const
override
{
return
"Parameter"
;
}
virtual
void
propagate_types
()
override
;
virtual
std
::
string
get_node_id
()
const
override
;
protected
:
Function
*
m_function
;
size_t
m_index
;
};
namespace
op
namespace
op
{
{
/// Factory for frameworks
///
std
::
shared_ptr
<
ngraph
::
Parameter
>
/// Parameters are nodes that represent the arguments that will be passed to user-defined functions.
parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
=
nullptr
);
/// Function creation requires a sequence of parameters.
/// Convenience factory for tests
/// Basic graph operations do not need parameters attached to a function.
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
element
::
Type
element_type
,
///
const
Shape
&
shape
);
class
Parameter
:
public
Node
{
friend
class
ngraph
::
Function
;
protected
:
// Called by the Function constructor to associate this parameter with the function.
// It is an error to try to associate a parameter with more than one function.
void
assign_function
(
Function
*
function
,
size_t
index
);
public
:
Parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
);
Parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
);
std
::
string
description
()
const
override
{
return
"Parameter"
;
}
virtual
void
propagate_types
()
override
;
virtual
std
::
string
get_node_id
()
const
override
;
protected
:
Function
*
m_function
;
size_t
m_index
;
};
}
}
}
}
src/ngraph/ops/tuple.hpp
View file @
36e36e7f
...
@@ -18,18 +18,16 @@ namespace ngraph
...
@@ -18,18 +18,16 @@ namespace ngraph
{
{
namespace
op
namespace
op
{
{
std
::
shared_ptr
<
Node
>
tuple
(
const
Nodes
&
args
);
class
Tuple
:
public
Builtin
}
class
TupleOp
:
public
BuiltinOp
{
public
:
TupleOp
(
const
Nodes
&
args
)
:
BuiltinOp
(
args
)
{
{
}
public
:
Tuple
(
const
Nodes
&
args
)
:
Builtin
(
args
)
{
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"tuple"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"Tuple"
;
}
virtual
void
propagate_types
()
override
;
virtual
void
propagate_types
()
override
;
};
};
}
}
}
src/ops/broadcast.cpp
View file @
36e36e7f
...
@@ -15,20 +15,9 @@
...
@@ -15,20 +15,9 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
op
;
/// @param tensor The tensor view to be broadcast.
void
Broadcast
::
propagate_types
()
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
broadcast
(
const
std
::
shared_ptr
<
Node
>&
tensor
,
const
Shape
&
shape
,
AxisSet
&&
broadcast_axes
)
{
return
make_shared
<
BroadcastOp
>
(
tensor
,
shape
,
broadcast_axes
);
}
void
BroadcastOp
::
propagate_types
()
{
{
auto
arg_type
=
m_arguments
.
at
(
0
)
->
get_value_type
();
auto
arg_type
=
m_arguments
.
at
(
0
)
->
get_value_type
();
if
(
nullptr
==
arg_type
)
if
(
nullptr
==
arg_type
)
...
...
src/ops/concatenate.cpp
View file @
36e36e7f
...
@@ -17,14 +17,9 @@
...
@@ -17,14 +17,9 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
op
;
void
Concat
Op
::
propagate_types
()
void
Concat
::
propagate_types
()
{
{
throw
ngraph_error
(
"NIY"
);
throw
ngraph_error
(
"NIY"
);
}
}
std
::
shared_ptr
<
Node
>
op
::
concatenate
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
args
)
{
return
make_shared
<
ConcatOp
>
(
args
);
}
src/ops/constant.cpp
View file @
36e36e7f
...
@@ -14,6 +14,6 @@
...
@@ -14,6 +14,6 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
using
namespace
ngraph
;
using
namespace
ngraph
::
op
;
void
ScalarConstantBase
::
propagate_types
()
{}
void
ScalarConstantBase
::
propagate_types
()
{}
src/ops/convert.cpp
View file @
36e36e7f
...
@@ -17,15 +17,9 @@
...
@@ -17,15 +17,9 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
op
;
void
Convert
Op
::
propagate_types
()
void
Convert
::
propagate_types
()
{
{
throw
ngraph_error
(
"NIY"
);
throw
ngraph_error
(
"NIY"
);
}
}
shared_ptr
<
ConvertOp
>
op
::
convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
element
::
Type
&
element_type
)
{
return
make_shared
<
ConvertOp
>
(
arg
,
element_type
);
}
src/ops/dot.cpp
View file @
36e36e7f
...
@@ -17,16 +17,9 @@
...
@@ -17,16 +17,9 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
op
;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
void
Dot
::
propagate_types
()
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
DotOp
>
(
arg0
,
arg1
);
}
void
DotOp
::
propagate_types
()
{
{
auto
arg0_tensor_type
=
auto
arg0_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
get_value_type
());
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
get_value_type
());
...
...
src/ops/function.cpp
View file @
36e36e7f
...
@@ -18,7 +18,7 @@ using namespace std;
...
@@ -18,7 +18,7 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
;
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>&
parameters
)
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
)
:
m_result
(
result
)
:
m_result
(
result
)
,
m_parameters
(
parameters
)
,
m_parameters
(
parameters
)
,
m_name
(
"Function"
)
,
m_name
(
"Function"
)
...
@@ -29,15 +29,3 @@ Function::Function(const std::shared_ptr<Node>& result
...
@@ -29,15 +29,3 @@ Function::Function(const std::shared_ptr<Node>& result
parameter
->
assign_function
(
this
,
i
++
);
parameter
->
assign_function
(
this
,
i
++
);
}
}
}
}
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
initializer_list
<
shared_ptr
<
Parameter
>>&
parameters
)
{
return
make_shared
<
Function
>
(
result
,
parameters
);
}
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
vector
<
shared_ptr
<
Parameter
>>&
parameters
)
{
return
make_shared
<
Function
>
(
result
,
parameters
);
}
src/ops/op.cpp
View file @
36e36e7f
...
@@ -26,103 +26,3 @@ std::string ngraph::Op::get_node_id() const
...
@@ -26,103 +26,3 @@ std::string ngraph::Op::get_node_id() const
ss
<<
get_op_class_name
()
<<
"_"
<<
m_instance_id
;
ss
<<
get_op_class_name
()
<<
"_"
<<
m_instance_id
;
return
ss
.
str
();
return
ss
.
str
();
}
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
abs
(
const
std
::
shared_ptr
<
Node
>&
arg
)
{
return
make_shared
<
AbsOp
>
(
arg
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
add
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
AddOp
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
ceiling
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
CeilingOp
>
(
arg0
,
arg1
);
}
// 'convert',
// 'convolution',
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
DivideOp
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
exp
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
{
return
make_shared
<
ExpOp
>
(
arg0
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
floor
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
FloorOp
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
log
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
{
return
make_shared
<
LogOp
>
(
arg0
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
maximum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
MaximumOp
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
minimum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
MinimumOp
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
multiply
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
MultiplyOp
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
negative
(
const
std
::
shared_ptr
<
Node
>&
arg0
)
{
return
make_shared
<
NegativeOp
>
(
arg0
);
}
// 'pad',
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
power
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
PowerOp
>
(
arg0
,
arg1
);
}
//'reduce',
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
remainder
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
RemainderOp
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
reshape
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
Shape
&
shape
)
{
return
make_shared
<
ReshapeOp
>
(
arg0
,
shape
);
}
//'reverse',
//'rng',
// 'select',
//'slice',
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
subtract
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
SubtractOp
>
(
arg0
,
arg1
);
}
// 'transpose',
// 'while'
src/ops/parameter.cpp
View file @
36e36e7f
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
op
;
Parameter
::
Parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
Parameter
::
Parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
:
Node
(
value_type
)
:
Node
(
value_type
)
...
@@ -43,18 +43,7 @@ void Parameter::assign_function(Function* function, size_t index)
...
@@ -43,18 +43,7 @@ void Parameter::assign_function(Function* function, size_t index)
void
Parameter
::
propagate_types
()
{}
void
Parameter
::
propagate_types
()
{}
shared_ptr
<
Parameter
>
ngraph
::
op
::
parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
std
::
string
ngraph
::
op
::
Parameter
::
get_node_id
()
const
{
return
make_shared
<
Parameter
>
(
value_type
);
}
shared_ptr
<
Parameter
>
ngraph
::
op
::
parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
)
{
return
make_shared
<
Parameter
>
(
make_shared
<
TensorViewType
>
(
element_type
,
shape
));
}
std
::
string
ngraph
::
Parameter
::
get_node_id
()
const
{
{
stringstream
ss
;
stringstream
ss
;
ss
<<
"parameter_"
<<
m_instance_id
;
ss
<<
"parameter_"
<<
m_instance_id
;
...
...
src/ops/tuple.cpp
View file @
36e36e7f
...
@@ -17,14 +17,9 @@
...
@@ -17,14 +17,9 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
op
;
void
Tuple
Op
::
propagate_types
()
void
Tuple
::
propagate_types
()
{
{
throw
ngraph_error
(
"NIY"
);
throw
ngraph_error
(
"NIY"
);
}
}
std
::
shared_ptr
<
Node
>
op
::
tuple
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
args
)
{
return
make_shared
<
TupleOp
>
(
args
);
}
test/build_graph.cpp
View file @
36e36e7f
...
@@ -23,17 +23,17 @@ using namespace ngraph;
...
@@ -23,17 +23,17 @@ using namespace ngraph;
TEST
(
build_graph
,
build_simple
)
TEST
(
build_graph
,
build_simple
)
{
{
// Function with 4 parameters
// Function with 4 parameters
auto
arg0
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
7
,
3
});
auto
arg0
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
7
,
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
3
});
auto
arg1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
3
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
,
7
});
auto
arg2
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
,
7
});
auto
arg3
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
,
7
});
auto
arg3
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
,
7
});
auto
broadcast_1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
broadcast_1
=
make_shared
<
op
::
Broadcast
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
b1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
b1
=
make_shared
<
op
::
Broadcast
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
dot
=
node
<
DotOp
>
(
arg2
,
arg0
);
auto
dot
=
make_shared
<
op
::
Dot
>
(
arg2
,
arg0
);
ASSERT_EQ
(
dot
->
get_arguments
()[
0
],
arg2
);
ASSERT_EQ
(
dot
->
get_arguments
()[
0
],
arg2
);
ASSERT_EQ
(
dot
->
get_arguments
()[
1
],
arg0
);
ASSERT_EQ
(
dot
->
get_arguments
()[
1
],
arg0
);
auto
cluster_0
=
op
::
function
(
dot
,
{
arg0
,
arg1
,
arg2
,
arg3
});
auto
cluster_0
=
make_shared
<
Function
>
(
dot
,
op
::
Parameters
{
arg0
,
arg1
,
arg2
,
arg3
});
ASSERT_EQ
(
cluster_0
->
get_result
(),
dot
);
ASSERT_EQ
(
cluster_0
->
get_result
(),
dot
);
}
}
...
@@ -59,15 +59,15 @@ TEST(build_graph, as_type)
...
@@ -59,15 +59,15 @@ TEST(build_graph, as_type)
// Check node comparisons
// Check node comparisons
TEST
(
build_graph
,
node_comparison
)
TEST
(
build_graph
,
node_comparison
)
{
{
auto
arg0
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
,
3
});
auto
arg0
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
,
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
3
});
auto
arg1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
3
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
});
auto
arg2
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
});
auto
dot
=
op
::
dot
(
arg0
,
arg1
);
auto
dot
=
make_shared
<
op
::
Dot
>
(
arg0
,
arg1
);
auto
add
=
op
::
add
(
dot
,
arg2
);
auto
add
=
make_shared
<
op
::
Add
>
(
dot
,
arg2
);
auto
parg
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{});
auto
parg
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{});
auto
pattern_dot
=
node
<
DotOp
>
(
parg
,
parg
);
auto
pattern_dot
=
make_shared
<
op
::
Dot
>
(
parg
,
parg
);
ASSERT_TRUE
(
pattern_dot
->
is_same_op_type
(
dot
));
ASSERT_TRUE
(
pattern_dot
->
is_same_op_type
(
dot
));
// TODO This passes because typeid is not behaving as documented.
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
// Need to figure out what's wrong.
...
@@ -78,20 +78,20 @@ TEST(build_graph, literal)
...
@@ -78,20 +78,20 @@ TEST(build_graph, literal)
{
{
// float scalar from a float
// float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0);
//auto float0 = FloatScalarConstant::make(3.0);
auto
float0
=
node
<
Float32ScalarConstant
>
(
3.0
);
auto
float0
=
make_shared
<
op
::
Float32ScalarConstant
>
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{});
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{});
ASSERT_EQ
(
float0
->
get_value
(),
3.0
);
ASSERT_EQ
(
float0
->
get_value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
get_value_type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float0
->
get_value_type
(),
float_scalar_type
);
auto
d
=
node
<
DotOp
>
(
float0
,
float0
);
auto
d
=
make_shared
<
op
::
Dot
>
(
float0
,
float0
);
ASSERT_EQ
(
d
->
get_arguments
().
at
(
0
),
float0
);
ASSERT_EQ
(
d
->
get_arguments
().
at
(
0
),
float0
);
ASSERT_EQ
(
d
->
get_arguments
().
at
(
1
),
float0
);
ASSERT_EQ
(
d
->
get_arguments
().
at
(
1
),
float0
);
// float scalar from an int
// float scalar from an int
auto
float1
=
node
<
Float32ScalarConstant
>
(
3
);
auto
float1
=
make_shared
<
op
::
Float32ScalarConstant
>
(
3
);
ASSERT_EQ
(
float1
->
get_value
(),
3
);
ASSERT_EQ
(
float1
->
get_value
(),
3
);
ASSERT_EQ
(
*
float1
->
get_value_type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float1
->
get_value_type
(),
float_scalar_type
);
auto
int32_0
=
node
<
Int32ScalarConstant
>
(
3.0
);
auto
int32_0
=
make_shared
<
op
::
Int32ScalarConstant
>
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{});
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{});
ASSERT_EQ
(
int32_0
->
get_value
(),
3
);
ASSERT_EQ
(
int32_0
->
get_value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
get_value_type
(),
int32_scalar_type
);
ASSERT_EQ
(
*
int32_0
->
get_value_type
(),
int32_scalar_type
);
...
...
test/op.cpp
View file @
36e36e7f
...
@@ -23,7 +23,7 @@ using namespace ngraph;
...
@@ -23,7 +23,7 @@ using namespace ngraph;
TEST
(
op
,
is_op
)
TEST
(
op
,
is_op
)
{
{
auto
arg0
=
op
::
parameter
(
element
::
Float32
::
element_type
(),
{
1
});
auto
arg0
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
1
});
ASSERT_NE
(
nullptr
,
arg0
);
ASSERT_NE
(
nullptr
,
arg0
);
EXPECT_TRUE
(
arg0
->
is_parameter
());
EXPECT_TRUE
(
arg0
->
is_parameter
());
EXPECT_FALSE
(
arg0
->
is_op
());
EXPECT_FALSE
(
arg0
->
is_op
());
...
@@ -31,9 +31,9 @@ TEST(op, is_op)
...
@@ -31,9 +31,9 @@ TEST(op, is_op)
TEST
(
op
,
is_parameter
)
TEST
(
op
,
is_parameter
)
{
{
auto
arg0
=
op
::
parameter
(
element
::
Float32
::
element_type
(),
{
1
});
auto
arg0
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
1
});
ASSERT_NE
(
nullptr
,
arg0
);
ASSERT_NE
(
nullptr
,
arg0
);
auto
t0
=
op
::
add
(
arg0
,
arg0
);
auto
t0
=
make_shared
<
op
::
Add
>
(
arg0
,
arg0
);
ASSERT_NE
(
nullptr
,
t0
);
ASSERT_NE
(
nullptr
,
t0
);
EXPECT_FALSE
(
t0
->
is_parameter
());
EXPECT_FALSE
(
t0
->
is_parameter
());
EXPECT_TRUE
(
t0
->
is_op
());
EXPECT_TRUE
(
t0
->
is_op
());
...
...
test/topological_sort.cpp
View file @
36e36e7f
...
@@ -58,30 +58,30 @@ static bool validate_list(const vector<Node*>& nodes)
...
@@ -58,30 +58,30 @@ static bool validate_list(const vector<Node*>& nodes)
TEST
(
topological_sort
,
basic
)
TEST
(
topological_sort
,
basic
)
{
{
vector
<
shared_ptr
<
Parameter
>>
args
;
vector
<
shared_ptr
<
op
::
Parameter
>>
args
;
for
(
int
i
=
0
;
i
<
10
;
i
++
)
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
{
auto
arg
=
op
::
parameter
(
element
::
Float32
::
element_type
(),
{
1
});
auto
arg
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
1
});
ASSERT_NE
(
nullptr
,
arg
);
ASSERT_NE
(
nullptr
,
arg
);
args
.
push_back
(
arg
);
args
.
push_back
(
arg
);
}
}
auto
t0
=
op
::
add
(
args
[
0
],
args
[
1
]);
auto
t0
=
make_shared
<
op
::
Add
>
(
args
[
0
],
args
[
1
]);
ASSERT_NE
(
nullptr
,
t0
);
ASSERT_NE
(
nullptr
,
t0
);
auto
t1
=
op
::
dot
(
t0
,
args
[
2
]);
auto
t1
=
make_shared
<
op
::
Dot
>
(
t0
,
args
[
2
]);
ASSERT_NE
(
nullptr
,
t1
);
ASSERT_NE
(
nullptr
,
t1
);
auto
t2
=
op
::
multiply
(
t0
,
args
[
3
]);
auto
t2
=
make_shared
<
op
::
Multiply
>
(
t0
,
args
[
3
]);
ASSERT_NE
(
nullptr
,
t2
);
ASSERT_NE
(
nullptr
,
t2
);
auto
t3
=
op
::
add
(
t1
,
args
[
4
]);
auto
t3
=
make_shared
<
op
::
Add
>
(
t1
,
args
[
4
]);
ASSERT_NE
(
nullptr
,
t2
);
ASSERT_NE
(
nullptr
,
t2
);
auto
t4
=
op
::
add
(
t2
,
args
[
5
]);
auto
t4
=
make_shared
<
op
::
Add
>
(
t2
,
args
[
5
]);
ASSERT_NE
(
nullptr
,
t3
);
ASSERT_NE
(
nullptr
,
t3
);
auto
r0
=
op
::
add
(
t3
,
t4
);
auto
r0
=
make_shared
<
op
::
Add
>
(
t3
,
t4
);
ASSERT_NE
(
nullptr
,
r0
);
ASSERT_NE
(
nullptr
,
r0
);
auto
f0
=
op
::
function
(
r0
,
args
);
auto
f0
=
make_shared
<
Function
>
(
r0
,
args
);
ASSERT_NE
(
nullptr
,
f0
);
ASSERT_NE
(
nullptr
,
f0
);
ASSERT_EQ
(
2
,
r0
->
get_arguments
().
size
());
ASSERT_EQ
(
2
,
r0
->
get_arguments
().
size
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment