Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
973b3a0e
Commit
973b3a0e
authored
Sep 01, 2017
by
Robert Kimball
Committed by
GitHub
Sep 01, 2017
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #74 from NervanaSystems/cyphers/names
Cyphers/names
parents
fac27c37
fd881acc
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
35 changed files
with
369 additions
and
385 deletions
+369
-385
clang_4_0_flags.cmake
cmake/clang_4_0_flags.cmake
+1
-0
common.hpp
src/ngraph/common.hpp
+17
-4
element_type.hpp
src/ngraph/element_type.hpp
+55
-18
function.hpp
src/ngraph/function.hpp
+12
-10
node.cpp
src/ngraph/node.cpp
+7
-6
node.hpp
src/ngraph/node.hpp
+44
-29
op.hpp
src/ngraph/op.hpp
+0
-0
broadcast.hpp
src/ngraph/ops/broadcast.hpp
+15
-15
concatenate.hpp
src/ngraph/ops/concatenate.hpp
+4
-4
constant.hpp
src/ngraph/ops/constant.hpp
+18
-26
convert.hpp
src/ngraph/ops/convert.hpp
+5
-5
dot.hpp
src/ngraph/ops/dot.hpp
+4
-3
parameter.hpp
src/ngraph/ops/parameter.hpp
+12
-10
tuple.hpp
src/ngraph/ops/tuple.hpp
+2
-2
shape.hpp
src/ngraph/shape.hpp
+4
-7
topological_sort.cpp
src/ngraph/topological_sort.cpp
+6
-7
topological_sort.hpp
src/ngraph/topological_sort.hpp
+4
-3
type.hpp
src/ngraph/type.hpp
+21
-83
visualize.cpp
src/ngraph/visualize.cpp
+8
-9
broadcast.cpp
src/ops/broadcast.cpp
+10
-12
concatenate.cpp
src/ops/concatenate.cpp
+3
-3
constant.cpp
src/ops/constant.cpp
+1
-1
convert.cpp
src/ops/convert.cpp
+2
-1
dot.cpp
src/ops/dot.cpp
+10
-7
function.cpp
src/ops/function.cpp
+3
-3
op.cpp
src/ops/op.cpp
+29
-19
parameter.cpp
src/ops/parameter.cpp
+10
-7
tuple.cpp
src/ops/tuple.cpp
+1
-1
element_type.cpp
src/types/element_type.cpp
+0
-26
type.cpp
src/types/type.cpp
+5
-5
util.cpp
src/util.cpp
+5
-5
build_graph.cpp
test/build_graph.cpp
+38
-38
op.cpp
test/op.cpp
+2
-2
topological_sort.cpp
test/topological_sort.cpp
+10
-11
util.cpp
test/util.cpp
+1
-3
No files found.
cmake/clang_4_0_flags.cmake
View file @
973b3a0e
...
...
@@ -39,6 +39,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded")
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-potentially-evaluated-expression"
)
# Triggers false alarms on typeid
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-sign-compare"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-unused-parameter"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-weak-vtables"
)
# Not ready for this yet
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-equal")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-duplicate-enum") # from numpy
...
...
src/ngraph/common.hpp
View file @
973b3a0e
...
...
@@ -15,20 +15,33 @@
#pragma once
#include <memory>
#include <vector>
#include <set>
#include <vector>
// Names for types that aren't worth giving their own classes
namespace
ngraph
{
class
Node
;
class
Parameter
;
class
ValueType
;
template
<
typename
T
,
typename
...
A
>
std
::
shared_ptr
<
T
>
node
(
A
&&
...
args
)
{
return
std
::
make_shared
<
T
>
(
args
...);
}
/// Zero or more value types
using
ValueTypes
=
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
;
/// Zero or more nodes
using
Nodes
=
std
::
vector
<
std
::
shared_ptr
<
Node
>>
;
/// A set of indices, for example, reduction axes
using
IndexSet
=
std
::
set
<
size_t
>
;
/// A sequence of axes
using
AxisVector
=
std
::
vector
<
size_t
>
;
/// A set of axes, for example, reduction axes
using
AxisSet
=
std
::
set
<
size_t
>
;
/// A list of parameters
using
Parameters
=
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>
;
...
...
src/ngraph/element_type.hpp
View file @
973b3a0e
...
...
@@ -22,6 +22,8 @@
#include <string>
#include <type_traits>
#include "except.hpp"
namespace
ngraph
{
namespace
element
...
...
@@ -41,42 +43,77 @@ namespace ngraph
bool
operator
==
(
const
Type
&
other
)
const
;
bool
operator
!=
(
const
Type
&
other
)
const
{
return
!
(
*
this
==
other
);
}
private
:
static
std
::
map
<
std
::
string
,
Type
>
m_element_list
;
size_t
m_bitwidth
;
bool
m_is_float
;
bool
m_is_signed
;
const
std
::
string
m_cname
;
size_t
m_bitwidth
;
bool
m_is_float
;
bool
m_is_signed
;
const
std
::
string
&
m_cname
;
};
// Provides a compile-time name for a C++ type.
// Used in TraitedType for the string that supplies the C++ type name during code generation,
// so it needs to be a valid C++ name.
template
<
typename
T
>
const
char
*
traited_type_name
()
{
throw
ngraph_error
(
"Unknown type"
);
}
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TRAITED_TYPE_NAME(T) \
template <> \
constexpr const char* traited_type_name<T>() \
{ \
return #T; \
}
// Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation.
template
<
typename
T
>
class
TraitedType
:
public
Type
{
public
:
// This is the C++ type used to hold a value of this element type during compilation
using
ctype
=
T
;
// This is a reference to an instance of this element type.
static
const
TraitedType
<
T
>&
type
;
TraitedType
(
const
std
::
string
&
cname
)
protected
:
TraitedType
()
:
Type
(
sizeof
(
T
)
*
8
,
std
::
is_floating_point
<
T
>::
value
,
std
::
is_signed
<
T
>::
value
,
cname
)
traited_type_name
<
T
>
()
)
{
}
public
:
// This is the C++ type used to hold a value of this element type during compilation
using
type
=
T
;
// This returns a reference to an instance of this element type.
static
const
TraitedType
<
T
>&
element_type
()
{
static
TraitedType
<
T
>
t
;
return
t
;
}
};
// Human-readable names for the element types
using
Float
=
TraitedType
<
float
>
;
using
Int8
=
TraitedType
<
int8_t
>
;
using
Int32
=
TraitedType
<
int32_t
>
;
using
Int64
=
TraitedType
<
int64_t
>
;
using
UInt8
=
TraitedType
<
uint8_t
>
;
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
float
)
using
Float32
=
TraitedType
<
float
>
;
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
int8_t
)
using
Int8
=
TraitedType
<
int8_t
>
;
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
int32_t
)
using
Int32
=
TraitedType
<
int32_t
>
;
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
int64_t
)
using
Int64
=
TraitedType
<
int64_t
>
;
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
uint8_t
)
using
UInt8
=
TraitedType
<
uint8_t
>
;
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
uint32_t
)
using
UInt32
=
TraitedType
<
uint32_t
>
;
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
uint64_t
)
using
UInt64
=
TraitedType
<
uint64_t
>
;
}
}
src/ngraph/function.hpp
View file @
973b3a0e
...
...
@@ -21,20 +21,22 @@
namespace
ngraph
{
/**
** A user-defined function.
**/
/// A user-defined function.
class
Function
{
public
:
Function
(
const
Node
::
ptr
&
result
,
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
Node
::
ptr
result
()
{
return
m_result
;
}
Parameter
::
ptr
parameter
(
size_t
i
)
{
return
m_parameters
[
i
];
}
std
::
string
name
()
const
{
return
m_name
;
}
std
::
shared_ptr
<
Node
>
get_result
()
{
return
m_result
;
}
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>
get_parameters
()
const
{
return
m_parameters
;
}
std
::
string
get_name
()
const
{
return
m_name
;
}
protected
:
Node
::
ptr
m_result
;
std
::
shared_ptr
<
Node
>
m_result
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>
m_parameters
;
std
::
string
m_name
;
};
...
...
@@ -42,10 +44,10 @@ namespace ngraph
namespace
op
{
std
::
shared_ptr
<
Function
>
function
(
const
Node
::
ptr
&
result
,
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
initializer_list
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
std
::
shared_ptr
<
Function
>
function
(
const
Node
::
ptr
&
result
,
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
}
}
src/ngraph/node.cpp
View file @
973b3a0e
...
...
@@ -17,9 +17,10 @@
size_t
ngraph
::
Node
::
m_next_instance_id
=
0
;
ngraph
::
Node
::
Node
(
const
std
::
vector
<
Node
::
ptr
>&
arguments
,
ValueType
::
ptr
type
)
:
TypedValueMixin
(
type
)
,
m_arguments
(
arguments
)
ngraph
::
Node
::
Node
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
,
std
::
shared_ptr
<
ValueType
>
value_type
)
:
m_arguments
(
arguments
)
,
m_value_type
(
value_type
)
,
m_instance_id
(
m_next_instance_id
++
)
{
// Add this node as a user of each argument.
...
...
@@ -47,15 +48,15 @@ namespace ngraph
auto
parameter_tmp
=
dynamic_cast
<
const
ngraph
::
Op
*>
(
&
node
);
if
(
op_tmp
)
{
out
<<
"Op("
<<
op_tmp
->
node_id
()
<<
")"
;
out
<<
"Op("
<<
op_tmp
->
get_
node_id
()
<<
")"
;
}
else
if
(
parameter_tmp
)
{
out
<<
"Parameter("
<<
parameter_tmp
->
node_id
()
<<
")"
;
out
<<
"Parameter("
<<
parameter_tmp
->
get_
node_id
()
<<
")"
;
}
else
{
out
<<
"Node("
<<
node
.
node_id
()
<<
")"
;
out
<<
"Node("
<<
node
.
get_
node_id
()
<<
")"
;
}
return
out
;
}
...
...
src/ngraph/node.hpp
View file @
973b3a0e
...
...
@@ -20,27 +20,32 @@
#include <iostream>
#include "type.hpp"
#include "common.hpp"
#include "type.hpp"
namespace
ngraph
{
class
Op
;
/**
** Nodes are the backbone of the graph of Value dataflow. Every node has
** zero or more nodes as arguments and one value, which is either a tensor
** view or a (possibly empty) tuple of values.
**/
class
Node
:
public
TypedValueMixin
,
public
std
::
enable_shared_from_this
<
Node
>
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values.
class
Node
:
public
std
::
enable_shared_from_this
<
Node
>
{
public
:
using
ptr
=
std
::
shared_ptr
<
Node
>
;
protected
:
Node
(
const
Nodes
&
arguments
,
ValueType
::
ptr
type
=
nullptr
);
Node
(
const
Nodes
&
arguments
,
std
::
shared_ptr
<
ValueType
>
value_type
=
nullptr
);
Node
()
:
Node
({},
nullptr
)
{
}
Node
(
std
::
shared_ptr
<
ValueType
>
value_type
)
:
Node
({},
value_type
)
{
}
virtual
~
Node
()
{}
public
:
/// A "one-liner" describing this node.
virtual
std
::
string
description
()
const
=
0
;
...
...
@@ -48,38 +53,48 @@ namespace ngraph
/// Propagate types and check arguments for consistency
virtual
void
propagate_types
()
=
0
;
const
Nodes
&
arguments
()
const
{
return
m_arguments
;
}
const
Nodes
&
get_
arguments
()
const
{
return
m_arguments
;
}
const
std
::
multiset
<
Node
*>&
users
()
const
{
return
m_users
;
}
std
::
string
name
()
const
{
return
m_name
;
}
void
name
(
const
std
::
string
&
name
)
{
m_name
=
name
;
}
std
::
string
get_
name
()
const
{
return
m_name
;
}
void
set_
name
(
const
std
::
string
&
name
)
{
m_name
=
name
;
}
virtual
std
::
string
node_id
()
const
=
0
;
virtual
std
::
string
get_
node_id
()
const
=
0
;
/**
** Return true if this has the same implementing class as node. This
** will be used by the pattern matcher when comparing a pattern
** graph against the graph.
**/
bool
is_same_op_type
(
const
Node
::
ptr
&
node
)
const
/// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern
/// graph against the graph.
bool
is_same_op_type
(
const
std
::
shared_ptr
<
Node
>&
node
)
const
{
return
typeid
(
*
this
)
==
typeid
(
*
node
.
get
());
}
std
::
shared_ptr
<
ValueType
>
get_value_type
()
{
return
m_value_type
;
}
const
std
::
shared_ptr
<
ValueType
>
get_value_type
()
const
{
return
m_value_type
;
}
void
set_value_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
{
m_value_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
}
void
set_value_type
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
{
m_value_type
=
value_type
;
}
bool
is_op
()
const
;
bool
is_parameter
()
const
;
size_t
instance_id
()
const
{
return
m_instance_id
;
}
size_t
get_
instance_id
()
const
{
return
m_instance_id
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
protected
:
Nodes
m_arguments
;
std
::
multiset
<
Node
*>
m_users
;
std
::
string
m_name
;
size_t
m_instance_id
;
static
size_t
m_next_instance_id
;
Nodes
m_arguments
;
std
::
shared_ptr
<
ValueType
>
m_value_type
;
std
::
multiset
<
Node
*>
m_users
;
std
::
string
m_name
;
size_t
m_instance_id
;
static
size_t
m_next_instance_id
;
};
using
node_ptr
=
std
::
shared_ptr
<
Node
>
;
}
src/ngraph/op.hpp
View file @
973b3a0e
This diff is collapsed.
Click to expand it.
src/ngraph/ops/broadcast.hpp
View file @
973b3a0e
...
...
@@ -19,33 +19,33 @@ namespace ngraph
class
BroadcastOp
:
public
BuiltinOp
{
public
:
using
Axes
=
std
::
vector
<
size_t
>
;
/
**
** /param arg The tensor view to be
broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastOp
(
const
Node
::
ptr
&
arg
,
const
Shape
&
shape
,
const
Axes
&
broadcast_axes
)
///
/// @param arg The tensor view to be broadcast.
/
// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being
broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
///
BroadcastOp
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
)
:
BuiltinOp
({
arg
})
,
m_shape
(
shape
)
,
m_broadcast_axes
(
broadcast_axes
)
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"broadcast"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"broadcast"
;
}
virtual
void
propagate_types
()
override
;
protected
:
Shape
m_shape
;
Ax
es
m_broadcast_axes
;
Shape
m_shape
;
Ax
isSet
m_broadcast_axes
;
};
namespace
op
{
Node
::
ptr
broadcast
(
const
Node
::
ptr
&
tensor
,
const
Shape
&
shape
,
const
BroadcastOp
::
Axes
&&
broadcast_axes
);
std
::
shared_ptr
<
Node
>
broadcast
(
const
std
::
shared_ptr
<
Node
>&
tensor
,
const
Shape
&
shape
,
AxisSet
&&
broadcast_axes
);
}
}
src/ngraph/ops/concatenate.hpp
View file @
973b3a0e
...
...
@@ -18,18 +18,18 @@ namespace ngraph
{
namespace
op
{
Node
::
ptr
concatenate
(
const
Nodes
&
args
);
std
::
shared_ptr
<
Node
>
concatenate
(
const
Nodes
&
args
);
}
class
Concat
enate
Op
:
public
BuiltinOp
class
ConcatOp
:
public
BuiltinOp
{
public
:
Concat
enate
Op
(
const
Nodes
&
args
)
ConcatOp
(
const
Nodes
&
args
)
:
BuiltinOp
(
args
)
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"concatenate"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"concatenate"
;
}
virtual
void
propagate_types
()
override
;
};
}
src/ngraph/ops/constant.hpp
View file @
973b3a0e
...
...
@@ -21,10 +21,10 @@
namespace
ngraph
{
// Defines methods to all constant scalars
class
ScalarConstantBase
Op
:
public
Node
class
ScalarConstantBase
:
public
Node
{
protected
:
ScalarConstantBase
Op
(
const
std
::
shared_ptr
<
TensorViewType
>&
type
)
ScalarConstantBase
(
const
std
::
shared_ptr
<
TensorViewType
>&
type
)
:
Node
({},
type
)
{
}
...
...
@@ -35,47 +35,39 @@ namespace ngraph
// Implement a constant scalar for each element type.
// The static make method takes a
template
<
typename
T
>
class
ScalarConstant
Op
:
public
ScalarConstantBaseOp
class
ScalarConstant
:
public
ScalarConstantBase
{
public
:
// The ngraph element type
using
element_type
=
T
;
// The C++ type that holds the element type
using
ctype
=
typename
T
::
c
type
;
using
type
=
typename
T
::
type
;
ScalarConstant
Op
(
typename
T
::
c
type
value
)
:
ScalarConstantBase
Op
(
std
::
make_shared
<
TensorViewType
>
(
T
::
type
,
Shape
{}))
ScalarConstant
(
typename
T
::
type
value
)
:
ScalarConstantBase
(
std
::
make_shared
<
TensorViewType
>
(
T
::
element_type
()
,
Shape
{}))
,
m_value
(
value
)
{
}
virtual
std
::
string
description
()
const
override
{
return
"
ConstantScalar
"
;
}
virtual
std
::
string
node_id
()
const
override
virtual
std
::
string
description
()
const
override
{
return
"
ScalarConstant
"
;
}
virtual
std
::
string
get_
node_id
()
const
override
{
std
::
stringstream
ss
;
ss
<<
description
()
<<
"_"
<<
node_id
()
;
ss
<<
description
()
<<
"_"
/* << node_id() */
;
return
ss
.
str
();
}
typename
T
::
ctype
value
()
const
{
return
m_value
;
}
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
template
<
typename
U
>
static
std
::
shared_ptr
<
ScalarConstantOp
<
T
>>
make
(
U
value
)
{
return
std
::
make_shared
<
ScalarConstantOp
<
T
>>
(
value
);
}
typename
T
::
type
get_value
()
const
{
return
m_value
;
}
protected
:
typename
T
::
c
type
m_value
;
typename
T
::
type
m_value
;
};
using
Float
ScalarConstantOp
=
ScalarConstantOp
<
element
::
Float
>
;
using
Int8ScalarConstant
Op
=
ScalarConstantOp
<
element
::
Int8
>
;
using
Int32ScalarConstant
Op
=
ScalarConstantOp
<
element
::
Int32
>
;
using
Int64ScalarConstant
Op
=
ScalarConstantOp
<
element
::
Int64
>
;
using
UInt8ScalarConstant
Op
=
ScalarConstantOp
<
element
::
UInt8
>
;
using
UInt32ScalarConstant
Op
=
ScalarConstantOp
<
element
::
UInt32
>
;
using
UInt64ScalarConstant
Op
=
ScalarConstantOp
<
element
::
UInt64
>
;
using
Float
32ScalarConstant
=
ScalarConstant
<
element
::
Float32
>
;
using
Int8ScalarConstant
=
ScalarConstant
<
element
::
Int8
>
;
using
Int32ScalarConstant
=
ScalarConstant
<
element
::
Int32
>
;
using
Int64ScalarConstant
=
ScalarConstant
<
element
::
Int64
>
;
using
UInt8ScalarConstant
=
ScalarConstant
<
element
::
UInt8
>
;
using
UInt32ScalarConstant
=
ScalarConstant
<
element
::
UInt32
>
;
using
UInt64ScalarConstant
=
ScalarConstant
<
element
::
UInt64
>
;
}
src/ngraph/ops/convert.hpp
View file @
973b3a0e
...
...
@@ -16,25 +16,25 @@
namespace
ngraph
{
class
ConvertOp
:
public
BuiltinOp
{
public
:
ConvertOp
(
const
Node
::
ptr
&
arg
,
const
ngraph
::
element
::
Type
&
element_type
)
ConvertOp
(
const
std
::
shared_ptr
<
Node
>
&
arg
,
const
ngraph
::
element
::
Type
&
element_type
)
:
BuiltinOp
({
arg
})
,
m_element_type
(
element_type
)
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"convert"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"convert"
;
}
virtual
void
propagate_types
()
override
;
protected
:
const
ngraph
::
element
::
Type
&
m_element_type
;
};
namespace
op
{
std
::
shared_ptr
<
ngraph
::
ConvertOp
>
convert
(
const
Node
::
ptr
&
arg
,
const
ngraph
::
element
::
Type
&
element_type
);
std
::
shared_ptr
<
ngraph
::
ConvertOp
>
convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
ngraph
::
element
::
Type
&
element_type
);
}
}
src/ngraph/ops/dot.hpp
View file @
973b3a0e
...
...
@@ -20,17 +20,18 @@ namespace ngraph
{
public
:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
DotOp
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>
&
arg1
)
:
BuiltinOp
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"dot"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"dot"
;
}
virtual
void
propagate_types
()
override
;
};
namespace
op
{
Node
::
ptr
dot
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
std
::
shared_ptr
<
Node
>
dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
}
}
src/ngraph/ops/parameter.hpp
View file @
973b3a0e
...
...
@@ -21,11 +21,11 @@ namespace ngraph
{
class
Function
;
/
**
**
Parameters are nodes that represent the arguments that will be passed to user-defined functions.
**
Function creation requires a sequence of parameters.
**
Basic graph operations do not need parameters attached to a function.
**
/
/
//
///
Parameters are nodes that represent the arguments that will be passed to user-defined functions.
///
Function creation requires a sequence of parameters.
///
Basic graph operations do not need parameters attached to a function.
//
/
class
Parameter
:
public
Node
{
friend
class
Function
;
...
...
@@ -36,11 +36,12 @@ namespace ngraph
void
assign_function
(
Function
*
function
,
size_t
index
);
public
:
Parameter
(
const
ValueType
::
ptr
&
value_type
);
Parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
);
Parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
);
std
::
string
description
()
const
override
{
return
"Parameter"
;
}
virtual
void
propagate_types
()
override
;
virtual
std
::
string
node_id
()
const
override
;
virtual
std
::
string
get_
node_id
()
const
override
;
protected
:
Function
*
m_function
;
...
...
@@ -50,9 +51,10 @@ namespace ngraph
namespace
op
{
/// Factory for frameworks
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
ValueType
::
ptr
&
value_type
=
nullptr
);
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
=
nullptr
);
/// Convenience factory for tests
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
);
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
element
::
Type
element_type
,
const
Shape
&
shape
);
}
}
src/ngraph/ops/tuple.hpp
View file @
973b3a0e
...
...
@@ -18,7 +18,7 @@ namespace ngraph
{
namespace
op
{
Node
::
ptr
tuple
(
const
Nodes
&
args
);
std
::
shared_ptr
<
Node
>
tuple
(
const
Nodes
&
args
);
}
class
TupleOp
:
public
BuiltinOp
...
...
@@ -29,7 +29,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"tuple"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"tuple"
;
}
virtual
void
propagate_types
()
override
;
};
}
src/ngraph/shape.hpp
View file @
973b3a0e
...
...
@@ -24,9 +24,7 @@ namespace ngraph
class
Shape
{
public
:
/**
** \param sizes A sequence of sizes.
**/
/// @param sizes A sequence of sizes.
Shape
(
const
std
::
initializer_list
<
size_t
>&
sizes
)
:
m_sizes
(
sizes
)
{
...
...
@@ -37,12 +35,11 @@ namespace ngraph
{
}
/**
** Conversion to a vector of sizes.
**/
operator
const
std
::
vector
<
size_t
>&
()
const
{
return
m_sizes
;
}
/// Conversion to a vector of sizes.
operator
const
std
::
vector
<
size_t
>&
()
const
{
return
m_sizes
;
}
bool
operator
==
(
const
Shape
&
shape
)
const
{
return
m_sizes
==
shape
.
m_sizes
;
}
bool
operator
!=
(
const
Shape
&
shape
)
const
{
return
m_sizes
!=
shape
.
m_sizes
;
}
protected
:
std
::
vector
<
size_t
>
m_sizes
;
};
...
...
src/ngraph/topological_sort.cpp
View file @
973b3a0e
...
...
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp"
using
namespace
ngraph
;
...
...
@@ -21,16 +21,16 @@ using namespace std;
void
ngraph
::
TopologicalSort
::
promote_node
(
Node
*
n
)
{
for
(
auto
dn
=
m_dependent_nodes
.
begin
();
dn
!=
m_dependent_nodes
.
end
();
dn
++
)
for
(
auto
dn
=
m_dependent_nodes
.
begin
();
dn
!=
m_dependent_nodes
.
end
();
dn
++
)
{
if
(
dn
->
first
>
0
)
// Skip zero as they should never be promoted
if
(
dn
->
first
>
0
)
// Skip zero as they should never be promoted
{
auto
it
=
find
(
dn
->
second
.
begin
(),
dn
->
second
.
end
(),
n
);
if
(
it
!=
dn
->
second
.
end
())
{
// found the node
dn
->
second
.
erase
(
it
);
m_dependent_nodes
[
dn
->
first
-
1
].
push_back
(
n
);
m_dependent_nodes
[
dn
->
first
-
1
].
push_back
(
n
);
}
}
}
...
...
@@ -38,9 +38,8 @@ void ngraph::TopologicalSort::promote_node(Node* n)
void
ngraph
::
TopologicalSort
::
process
(
node_ptr
p
)
{
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
{
list
<
Node
*>&
node_list
=
m_dependent_nodes
[
node
->
arguments
().
size
()];
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
{
list
<
Node
*>&
node_list
=
m_dependent_nodes
[
node
->
get_arguments
().
size
()];
node_list
.
push_back
(
node
.
get
());
});
...
...
src/ngraph/topological_sort.hpp
View file @
973b3a0e
...
...
@@ -14,9 +14,10 @@
#pragma once
#include <memory>
#include <map>
#include <list>
#include <map>
#include <memory>
#include <vector>
namespace
ngraph
{
...
...
@@ -30,7 +31,7 @@ class ngraph::TopologicalSort
public
:
TopologicalSort
()
{}
void
process
(
node_ptr
);
void
process
(
node_ptr
);
const
std
::
vector
<
Node
*>&
get_sorted_list
()
const
;
private
:
...
...
src/ngraph/type.hpp
View file @
973b3a0e
...
...
@@ -25,123 +25,61 @@ namespace ngraph
class
TensorViewType
;
class
TupleType
;
/**
** ValueType is
** TensorViewType
** | TupleType(ValueType[])
**/
/// ValueType is
/// TensorViewType
/// | TupleType(ValueType[])
class
ValueType
{
public
:
/**
** Preferred handle
**/
using
ptr
=
std
::
shared_ptr
<
ValueType
>
;
virtual
~
ValueType
()
{}
virtual
bool
operator
==
(
const
ValueType
::
ptr
&
that
)
const
=
0
;
bool
operator
!=
(
const
ValueType
::
ptr
&
that
)
const
{
return
!
(
*
this
==
that
);
}
virtual
bool
operator
==
(
const
std
::
shared_ptr
<
ValueType
>
&
that
)
const
=
0
;
bool
operator
!=
(
const
std
::
shared_ptr
<
ValueType
>
&
that
)
const
{
return
!
(
*
this
==
that
);
}
};
/**
** Describes a tensor view; an element type and a shape.
**/
/// Describes a tensor view; an element type and a shape.
class
TensorViewType
:
public
ValueType
{
public
:
/**
** Preferred handle
**/
using
ptr
=
std
::
shared_ptr
<
TensorViewType
>
;
/**
** /param element_type The type of the tensor elements.
** /param shape The shape of the tensor.
**/
/// /param element_type The type of the tensor elements.
/// /param shape The shape of the tensor.
TensorViewType
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
:
m_element_type
(
element_type
)
,
m_shape
(
shape
)
{
}
const
element
::
Type
&
element_type
()
const
{
return
m_element_type
;
}
const
Shape
&
shape
()
const
{
return
m_shape
;
}
const
element
::
Type
&
get_
element_type
()
const
{
return
m_element_type
;
}
const
Shape
&
get_
shape
()
const
{
return
m_shape
;
}
virtual
bool
operator
==
(
const
ValueType
::
ptr
&
that
)
const
override
;
virtual
bool
operator
==
(
const
std
::
shared_ptr
<
ValueType
>
&
that
)
const
override
;
protected
:
const
element
::
Type
&
m_element_type
;
Shape
m_shape
;
};
/**
** Describes a tuple of values; a vector of types
**/
/// Describes a tuple of values; a vector of types
class
TupleType
:
public
ValueType
{
public
:
/**
** The preferred handle
**/
using
ptr
=
std
::
shared_ptr
<
ValueType
>
;
/**
** Construct empty tuple and add value types later.
**/
/// Construct empty tuple and add value types later.
TupleType
()
{}
/**
** /param element_types A vector of types for the tuple elements
**/
TupleType
(
const
std
::
vector
<
ValueType
::
ptr
>&
element_types
)
:
m_element_types
(
element_types
)
{
}
const
std
::
vector
<
ValueType
::
ptr
>
element_types
()
const
{
return
m_element_types
;
}
std
::
vector
<
ValueType
::
ptr
>
element_types
()
{
return
m_element_types
;
}
virtual
bool
operator
==
(
const
ValueType
::
ptr
&
that
)
const
override
;
protected
:
std
::
vector
<
ValueType
::
ptr
>
m_element_types
;
};
/**
** Mixin for objects with type information
**/
class
TypedValueMixin
{
public
:
TypedValueMixin
(
const
ValueType
::
ptr
&
type
=
nullptr
)
:
m_type
(
type
)
/// @param element_types A vector of types for the tuple elements
TupleType
(
const
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>&
element_types
)
:
m_element_types
(
element_types
)
{
}
/**
** Set the type
** /param type The new type
**/
void
type
(
const
ValueType
::
ptr
&
type
)
{
m_type
=
type
;
}
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void
type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
const
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
get_element_types
()
const
{
m_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
)
;
return
m_element_types
;
}
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
set_element_types
()
{
return
m_element_types
;
}
/**
** The type associated with this value.
**/
ValueType
::
ptr
type
()
{
return
m_type
;
}
/**
** The type associated with this value.
**/
const
ValueType
::
ptr
type
()
const
{
return
m_type
;
}
virtual
bool
operator
==
(
const
std
::
shared_ptr
<
ValueType
>&
that
)
const
override
;
protected
:
ValueType
::
ptr
m_type
;
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
m_element_types
;
};
}
src/ngraph/visualize.cpp
View file @
973b3a0e
...
...
@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <list>
#include <fstream>
#include <cstdio>
#include <fstream>
#include <list>
#include "visualize.hpp"
#include "ngraph/node.hpp"
#include "util.hpp"
#include "visualize.hpp"
using
namespace
ngraph
;
using
namespace
std
;
...
...
@@ -31,18 +31,17 @@ Visualize::Visualize(const string& name)
void
Visualize
::
add
(
node_ptr
p
)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
{
for
(
auto
arg
:
node
->
arguments
())
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
{
for
(
auto
arg
:
node
->
get_arguments
())
{
m_ss
<<
" "
<<
arg
->
node_id
()
<<
" -> "
<<
node
->
node_id
()
<<
";
\n
"
;
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_
node_id
()
<<
";
\n
"
;
}
});
}
void
Visualize
::
save_dot
(
const
string
&
path
)
const
{
auto
tmp_file
=
path
+
".tmp"
;
auto
tmp_file
=
path
+
".tmp"
;
ofstream
out
(
tmp_file
);
if
(
out
)
{
...
...
@@ -53,7 +52,7 @@ void Visualize::save_dot(const string& path) const
stringstream
ss
;
ss
<<
"dot -Tpng "
<<
tmp_file
<<
" -o "
<<
path
;
auto
cmd
=
ss
.
str
();
auto
cmd
=
ss
.
str
();
auto
stream
=
popen
(
cmd
.
c_str
(),
"r"
);
pclose
(
stream
);
...
...
src/ops/broadcast.cpp
View file @
973b3a0e
...
...
@@ -17,22 +17,20 @@
using
namespace
std
;
using
namespace
ngraph
;
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
Node
::
ptr
ngraph
::
op
::
broadcast
(
const
Node
::
ptr
&
tensor
,
const
Shape
&
shape
,
const
BroadcastOp
::
Axes
&&
broadcast_axes
)
/// @param tensor The tensor view to be broadcast.
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
broadcast
(
const
std
::
shared_ptr
<
Node
>&
tensor
,
const
Shape
&
shape
,
AxisSet
&&
broadcast_axes
)
{
return
make_shared
<
BroadcastOp
>
(
tensor
,
shape
,
broadcast_axes
);
}
void
BroadcastOp
::
propagate_types
()
{
auto
arg_type
=
m_arguments
.
at
(
0
)
->
type
();
auto
arg_type
=
m_arguments
.
at
(
0
)
->
get_value_
type
();
if
(
nullptr
==
arg_type
)
{
throw
ngraph_error
(
"Argument to broadcast is missing type."
);
...
...
@@ -47,11 +45,11 @@ void BroadcastOp::propagate_types()
{
target_shape
.
erase
(
target_shape
.
begin
()
+
*
i
);
}
if
(
Shape
{
target_shape
}
!=
arg_tensor_view_type
->
shape
())
if
(
Shape
{
target_shape
}
!=
arg_tensor_view_type
->
get_
shape
())
{
throw
ngraph_error
(
"Broadcast arg, shape, and axes are incompatible"
);
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_
type
=
make_shared
<
TensorViewType
>
(
arg_tensor_view_type
->
element_type
(),
m_shape
);
m_
value_type
=
make_shared
<
TensorViewType
>
(
arg_tensor_view_type
->
get_
element_type
(),
m_shape
);
}
src/ops/concatenate.cpp
View file @
973b3a0e
...
...
@@ -19,12 +19,12 @@
using
namespace
std
;
using
namespace
ngraph
;
void
Concat
enate
Op
::
propagate_types
()
void
ConcatOp
::
propagate_types
()
{
throw
ngraph_error
(
"NIY"
);
}
Node
::
ptr
op
::
concatenate
(
const
std
::
vector
<
Node
::
ptr
>&
args
)
std
::
shared_ptr
<
Node
>
op
::
concatenate
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>
>&
args
)
{
return
make_shared
<
Concat
enate
Op
>
(
args
);
return
make_shared
<
ConcatOp
>
(
args
);
}
src/ops/constant.cpp
View file @
973b3a0e
...
...
@@ -16,4 +16,4 @@
using
namespace
ngraph
;
void
ScalarConstantBase
Op
::
propagate_types
()
{}
void
ScalarConstantBase
::
propagate_types
()
{}
src/ops/convert.cpp
View file @
973b3a0e
...
...
@@ -24,7 +24,8 @@ void ConvertOp::propagate_types()
throw
ngraph_error
(
"NIY"
);
}
shared_ptr
<
ConvertOp
>
op
::
convert
(
const
Node
::
ptr
&
arg
,
const
element
::
Type
&
element_type
)
shared_ptr
<
ConvertOp
>
op
::
convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
element
::
Type
&
element_type
)
{
return
make_shared
<
ConvertOp
>
(
arg
,
element_type
);
}
src/ops/dot.cpp
View file @
973b3a0e
...
...
@@ -20,28 +20,31 @@ using namespace std;
using
namespace
ngraph
;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node
::
ptr
ngraph
::
op
::
dot
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
DotOp
>
(
arg0
,
arg1
);
}
void
DotOp
::
propagate_types
()
{
auto
arg0_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
type
());
auto
arg1_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
1
)
->
type
());
auto
arg0_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
get_value_type
());
auto
arg1_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
1
)
->
get_value_type
());
if
(
nullptr
==
arg0_tensor_type
||
nullptr
==
arg1_tensor_type
)
{
throw
ngraph_error
(
"Arguments to dot must be tensor views"
);
}
if
(
arg0_tensor_type
->
element_type
()
!=
arg1_tensor_type
->
element_type
())
if
(
arg0_tensor_type
->
get_element_type
()
!=
arg1_tensor_type
->
get_
element_type
())
{
throw
ngraph_error
(
"Arguments to dot must have the same element type"
);
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector
<
size_t
>
arg0_shape
=
arg0_tensor_type
->
shape
();
vector
<
size_t
>
arg1_shape
=
arg1_tensor_type
->
shape
();
vector
<
size_t
>
arg0_shape
=
arg0_tensor_type
->
get_
shape
();
vector
<
size_t
>
arg1_shape
=
arg1_tensor_type
->
get_
shape
();
size_t
arg0_reduction
=
arg0_shape
.
size
()
-
1
;
size_t
arg1_reduction
;
if
(
arg1_shape
.
size
()
>
1
)
...
...
@@ -60,5 +63,5 @@ void DotOp::propagate_types()
copy
(
arg0_shape
.
begin
(),
arg0_shape
.
begin
()
+
arg1_reduction
,
result_shape
.
end
());
copy
(
arg1_shape
.
begin
(),
arg1_shape
.
begin
()
+
arg1_reduction
,
result_shape
.
end
());
copy
(
arg1_shape
.
begin
()
+
arg1_reduction
,
arg1_shape
.
end
(),
result_shape
.
end
());
m_
type
=
make_shared
<
TensorViewType
>
(
arg0_tensor_type
->
element_type
(),
result_shape
);
m_
value_type
=
make_shared
<
TensorViewType
>
(
arg0_tensor_type
->
get_
element_type
(),
result_shape
);
}
src/ops/function.cpp
View file @
973b3a0e
...
...
@@ -17,7 +17,7 @@
using
namespace
std
;
using
namespace
ngraph
;
Function
::
Function
(
const
Node
::
ptr
&
result
,
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>&
parameters
)
:
m_result
(
result
)
,
m_parameters
(
parameters
)
...
...
@@ -30,13 +30,13 @@ Function::Function(const Node::ptr& result
}
}
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
Node
::
ptr
&
result
,
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
initializer_list
<
shared_ptr
<
Parameter
>>&
parameters
)
{
return
make_shared
<
Function
>
(
result
,
parameters
);
}
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
Node
::
ptr
&
result
,
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
vector
<
shared_ptr
<
Parameter
>>&
parameters
)
{
return
make_shared
<
Function
>
(
result
,
parameters
);
...
...
src/ops/op.cpp
View file @
973b3a0e
...
...
@@ -20,24 +20,26 @@
using
namespace
ngraph
;
using
namespace
std
;
std
::
string
ngraph
::
Op
::
node_id
()
const
std
::
string
ngraph
::
Op
::
get_
node_id
()
const
{
stringstream
ss
;
ss
<<
op_class_name
()
<<
"_"
<<
m_instance_id
;
ss
<<
get_
op_class_name
()
<<
"_"
<<
m_instance_id
;
return
ss
.
str
();
}
Node
::
ptr
ngraph
::
op
::
abs
(
const
Node
::
ptr
&
arg
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
abs
(
const
std
::
shared_ptr
<
Node
>
&
arg
)
{
return
make_shared
<
AbsOp
>
(
arg
);
}
Node
::
ptr
ngraph
::
op
::
add
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
add
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
AddOp
>
(
arg0
,
arg1
);
}
Node
::
ptr
ngraph
::
op
::
ceiling
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
ceiling
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
CeilingOp
>
(
arg0
,
arg1
);
}
...
...
@@ -45,61 +47,68 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
// 'convert',
// 'convolution',
Node
::
ptr
ngraph
::
op
::
divide
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
DivideOp
>
(
arg0
,
arg1
);
}
Node
::
ptr
ngraph
::
op
::
exponential
(
const
Node
::
ptr
&
arg0
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
exp
(
const
std
::
shared_ptr
<
Node
>
&
arg0
)
{
return
make_shared
<
Exp
onential
Op
>
(
arg0
);
return
make_shared
<
ExpOp
>
(
arg0
);
}
Node
::
ptr
ngraph
::
op
::
floor
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
floor
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
FloorOp
>
(
arg0
,
arg1
);
}
Node
::
ptr
ngraph
::
op
::
log
(
const
Node
::
ptr
&
arg0
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
log
(
const
std
::
shared_ptr
<
Node
>
&
arg0
)
{
return
make_shared
<
LogOp
>
(
arg0
);
}
Node
::
ptr
ngraph
::
op
::
maximum
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
maximum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
MaximumOp
>
(
arg0
,
arg1
);
}
Node
::
ptr
ngraph
::
op
::
minimum
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
minimum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
MinimumOp
>
(
arg0
,
arg1
);
}
Node
::
ptr
ngraph
::
op
::
multiply
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
multiply
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
MultiplyOp
>
(
arg0
,
arg1
);
}
Node
::
ptr
ngraph
::
op
::
negate
(
const
Node
::
ptr
&
arg0
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
negative
(
const
std
::
shared_ptr
<
Node
>
&
arg0
)
{
return
make_shared
<
NegateOp
>
(
arg0
);
return
make_shared
<
Negat
iv
eOp
>
(
arg0
);
}
// 'pad',
Node
::
ptr
ngraph
::
op
::
power
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
power
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
PowerOp
>
(
arg0
,
arg1
);
}
//'reduce',
Node
::
ptr
ngraph
::
op
::
remainder
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
remainder
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
RemainderOp
>
(
arg0
,
arg1
);
}
Node
::
ptr
ngraph
::
op
::
reshape
(
const
Node
::
ptr
&
arg0
,
const
Shape
&
shape
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
reshape
(
const
std
::
shared_ptr
<
Node
>
&
arg0
,
const
Shape
&
shape
)
{
return
make_shared
<
ReshapeOp
>
(
arg0
,
shape
);
}
...
...
@@ -109,7 +118,8 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
// 'select',
//'slice',
Node
::
ptr
ngraph
::
op
::
subtract
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
subtract
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
make_shared
<
SubtractOp
>
(
arg0
,
arg1
);
}
...
...
src/ops/parameter.cpp
View file @
973b3a0e
...
...
@@ -19,13 +19,18 @@
using
namespace
std
;
using
namespace
ngraph
;
Parameter
::
Parameter
(
const
ValueType
::
ptr
&
value_type
)
:
Node
(
{},
value_type
)
Parameter
::
Parameter
(
const
std
::
shared_ptr
<
ValueType
>
&
value_type
)
:
Node
(
value_type
)
,
m_function
(
nullptr
)
,
m_index
(
0
)
{
}
Parameter
::
Parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
)
:
Parameter
(
make_shared
<
TensorViewType
>
(
element_type
,
shape
))
{
}
void
Parameter
::
assign_function
(
Function
*
function
,
size_t
index
)
{
if
(
nullptr
!=
m_function
)
...
...
@@ -36,11 +41,9 @@ void Parameter::assign_function(Function* function, size_t index)
m_index
=
index
;
}
void
Parameter
::
propagate_types
()
{
}
void
Parameter
::
propagate_types
()
{}
shared_ptr
<
Parameter
>
ngraph
::
op
::
parameter
(
const
ValueType
::
ptr
&
value_type
)
shared_ptr
<
Parameter
>
ngraph
::
op
::
parameter
(
const
std
::
shared_ptr
<
ValueType
>
&
value_type
)
{
return
make_shared
<
Parameter
>
(
value_type
);
}
...
...
@@ -51,7 +54,7 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_
return
make_shared
<
Parameter
>
(
make_shared
<
TensorViewType
>
(
element_type
,
shape
));
}
std
::
string
ngraph
::
Parameter
::
node_id
()
const
std
::
string
ngraph
::
Parameter
::
get_
node_id
()
const
{
stringstream
ss
;
ss
<<
"parameter_"
<<
m_instance_id
;
...
...
src/ops/tuple.cpp
View file @
973b3a0e
...
...
@@ -24,7 +24,7 @@ void TupleOp::propagate_types()
throw
ngraph_error
(
"NIY"
);
}
Node
::
ptr
op
::
tuple
(
const
std
::
vector
<
Node
::
ptr
>&
args
)
std
::
shared_ptr
<
Node
>
op
::
tuple
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>
>&
args
)
{
return
make_shared
<
TupleOp
>
(
args
);
}
src/types/element_type.cpp
View file @
973b3a0e
...
...
@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const
{
return
std
::
ceil
((
float
)
m_bitwidth
/
8.0
);
}
namespace
{
const
element
::
Float
s_float32_t
=
element
::
Float
{
"float"
};
const
element
::
Int8
s_int8_t
=
element
::
Int8
{
"int8_t"
};
const
element
::
Int32
s_int32_t
=
element
::
Int32
{
"int32_t"
};
const
element
::
Int64
s_int64_t
=
element
::
Int64
{
"int64_t"
};
const
element
::
UInt8
s_uint8_t
=
element
::
UInt8
{
"uint8_t"
};
const
element
::
UInt32
s_uint32_t
=
element
::
UInt32
{
"uint32_t"
};
const
element
::
UInt64
s_uint64_t
=
element
::
UInt64
{
"uint64_t"
};
}
template
<>
const
element
::
TraitedType
<
float
>&
element
::
TraitedType
<
float
>::
type
=
s_float32_t
;
template
<>
const
element
::
TraitedType
<
int8_t
>&
element
::
TraitedType
<
int8_t
>::
type
=
s_int8_t
;
template
<>
const
element
::
TraitedType
<
int32_t
>&
element
::
TraitedType
<
int32_t
>::
type
=
s_int32_t
;
template
<>
const
element
::
TraitedType
<
int64_t
>&
element
::
TraitedType
<
int64_t
>::
type
=
s_int64_t
;
template
<>
const
element
::
TraitedType
<
uint8_t
>&
element
::
TraitedType
<
uint8_t
>::
type
=
s_uint8_t
;
template
<>
const
element
::
TraitedType
<
uint32_t
>&
element
::
TraitedType
<
uint32_t
>::
type
=
s_uint32_t
;
template
<>
const
element
::
TraitedType
<
uint64_t
>&
element
::
TraitedType
<
uint64_t
>::
type
=
s_uint64_t
;
\ No newline at end of file
src/types/type.cpp
View file @
973b3a0e
...
...
@@ -19,30 +19,30 @@
using
namespace
std
;
using
namespace
ngraph
;
bool
TensorViewType
::
operator
==
(
const
ValueType
::
ptr
&
that
)
const
bool
TensorViewType
::
operator
==
(
const
std
::
shared_ptr
<
ValueType
>
&
that
)
const
{
auto
that_tvt
=
dynamic_pointer_cast
<
TensorViewType
>
(
that
);
if
(
nullptr
==
that_tvt
)
{
return
false
;
}
if
(
that_tvt
->
element_type
()
!=
m_element_type
)
if
(
that_tvt
->
get_
element_type
()
!=
m_element_type
)
{
return
false
;
}
if
(
that_tvt
->
shape
()
!=
m_shape
)
if
(
that_tvt
->
get_
shape
()
!=
m_shape
)
{
return
false
;
}
return
true
;
}
bool
TupleType
::
operator
==
(
const
ValueType
::
ptr
&
that
)
const
bool
TupleType
::
operator
==
(
const
std
::
shared_ptr
<
ValueType
>
&
that
)
const
{
auto
that_tvt
=
dynamic_pointer_cast
<
TupleType
>
(
that
);
if
(
nullptr
==
that_tvt
)
{
return
false
;
}
return
that_tvt
->
element_types
()
==
element_types
();
return
that_tvt
->
get_element_types
()
==
get_
element_types
();
}
src/util.cpp
View file @
973b3a0e
...
...
@@ -26,8 +26,8 @@ void ngraph::dump(ostream& out, const void* _data, size_t _size)
{
auto
flags
=
out
.
flags
();
const
uint8_t
*
data
=
reinterpret_cast
<
const
uint8_t
*>
(
_data
);
in
t
len
=
_size
;
in
t
index
=
0
;
size_
t
len
=
_size
;
size_
t
index
=
0
;
while
(
index
<
len
)
{
out
<<
std
::
hex
<<
std
::
setw
(
8
)
<<
std
::
setfill
(
'0'
)
<<
index
;
...
...
@@ -136,11 +136,11 @@ static void traverse_nodes(std::shared_ptr<ngraph::Node> p,
std
::
set
<
size_t
>&
instances_seen
)
{
f
(
p
);
for
(
auto
arg
:
p
->
arguments
())
for
(
auto
arg
:
p
->
get_
arguments
())
{
if
(
instances_seen
.
find
(
arg
->
instance_id
())
==
instances_seen
.
end
())
if
(
instances_seen
.
find
(
arg
->
get_
instance_id
())
==
instances_seen
.
end
())
{
instances_seen
.
insert
(
arg
->
instance_id
());
instances_seen
.
insert
(
arg
->
get_
instance_id
());
traverse_nodes
(
arg
,
f
,
instances_seen
);
}
}
...
...
test/build_graph.cpp
View file @
973b3a0e
...
...
@@ -16,40 +16,41 @@
#include "ngraph/ngraph.hpp"
#include <memory>
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
build_graph
,
build_simple
)
{
// Function with 4 parameters
auto
arg0
=
op
::
parameter
(
element
::
Float
::
type
,
{
7
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
type
,
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
type
,
{
32
,
7
});
auto
arg3
=
op
::
parameter
(
element
::
Float
::
type
,
{
32
,
7
});
auto
broadcast_1
=
op
::
broadcast
(
arg3
,
{
10
,
32
,
7
},
{
0
});
auto
dot
=
op
::
dot
(
arg2
,
arg0
);
ASSERT_EQ
(
2
,
dot
->
arguments
().
size
()
);
ASSERT_EQ
(
dot
->
arguments
()[
0
],
arg2
);
ASSERT_EQ
(
dot
->
arguments
()[
1
],
arg0
);
auto
arg0
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
7
,
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
3
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
,
7
});
auto
arg3
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
,
7
});
auto
broadcast_1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
b1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
}
);
auto
dot
=
node
<
DotOp
>
(
arg2
,
arg0
);
ASSERT_EQ
(
dot
->
get_
arguments
()[
0
],
arg2
);
ASSERT_EQ
(
dot
->
get_
arguments
()[
1
],
arg0
);
auto
cluster_0
=
op
::
function
(
dot
,
{
arg0
,
arg1
,
arg2
,
arg3
});
ASSERT_EQ
(
cluster_0
->
result
(),
dot
);
ASSERT_EQ
(
cluster_0
->
get_
result
(),
dot
);
}
// Check upcasting from ValueType.
TEST
(
build_graph
,
as_type
)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType
::
ptr
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
type
,
Shape
{
2
,
3
,
5
});
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
auto
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
()
,
Shape
{
2
,
3
,
5
});
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
ASSERT_EQ
(
tv_vt
,
tv_tv
);
auto
tv_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tv_vt
);
ASSERT_EQ
(
nullptr
,
tv_tp
);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
ValueType
::
ptr
tp_vt
=
make_shared
<
TupleType
>
(
vector
<
ValueType
::
ptr
>
{
tv_vt
,
tv_vt
});
auto
tp_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tp_vt
);
auto
tp_vt
=
make_shared
<
TupleType
>
(
ValueTypes
{
tv_vt
,
tv_vt
});
auto
tp_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tp_vt
);
ASSERT_EQ
(
nullptr
,
tp_tv
);
auto
tp_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tp_vt
);
ASSERT_EQ
(
tp_vt
,
tp_tp
);
...
...
@@ -58,15 +59,15 @@ TEST(build_graph, as_type)
// Check node comparisons
TEST
(
build_graph
,
node_comparison
)
{
auto
arg0
=
op
::
parameter
(
element
::
Float
::
type
,
{
32
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
type
,
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
type
,
{
32
});
auto
arg0
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
,
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
3
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
32
});
auto
dot
=
op
::
dot
(
arg0
,
arg1
);
auto
add
=
op
::
add
(
dot
,
arg2
);
auto
parg
=
op
::
parameter
(
element
::
Float
::
type
,
{});
auto
pattern_dot
=
op
::
dot
(
parg
,
parg
);
auto
parg
=
node
<
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{});
auto
pattern_dot
=
node
<
DotOp
>
(
parg
,
parg
);
ASSERT_TRUE
(
pattern_dot
->
is_same_op_type
(
dot
));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
...
...
@@ -76,27 +77,26 @@ TEST(build_graph, node_comparison)
TEST
(
build_graph
,
literal
)
{
// float scalar from a float
auto
float0
=
FloatScalarConstantOp
::
make
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
type
,
Shape
{});
ASSERT_EQ
(
float0
->
value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
type
(),
float_scalar_type
);
auto
d
=
op
::
dot
(
float0
,
float0
);
ASSERT_EQ
(
d
->
arguments
().
at
(
0
),
float0
);
ASSERT_EQ
(
d
->
arguments
().
at
(
1
),
float0
);
//auto float0 = FloatScalarConstant::make(3.0);
auto
float0
=
node
<
Float32ScalarConstant
>
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{});
ASSERT_EQ
(
float0
->
get_value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
get_value_type
(),
float_scalar_type
);
auto
d
=
node
<
DotOp
>
(
float0
,
float0
);
ASSERT_EQ
(
d
->
get_arguments
().
at
(
0
),
float0
);
ASSERT_EQ
(
d
->
get_arguments
().
at
(
1
),
float0
);
// float scalar from an int
auto
float1
=
FloatScalarConstantOp
::
make
(
3
);
ASSERT_EQ
(
float1
->
value
(),
3
);
ASSERT_EQ
(
*
float1
->
type
(),
float_scalar_type
);
auto
int32_0
=
Int32ScalarConstantOp
::
make
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
type
,
Shape
{});
ASSERT_EQ
(
int32_0
->
value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
type
(),
int32_scalar_type
);
ASSERT_NE
(
*
int32_0
->
type
(),
float_scalar_type
);
auto
float1
=
node
<
Float32ScalarConstant
>
(
3
);
ASSERT_EQ
(
float1
->
get_
value
(),
3
);
ASSERT_EQ
(
*
float1
->
get_value_
type
(),
float_scalar_type
);
auto
int32_0
=
node
<
Int32ScalarConstant
>
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
()
,
Shape
{});
ASSERT_EQ
(
int32_0
->
get_
value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
get_value_
type
(),
int32_scalar_type
);
ASSERT_NE
(
*
int32_0
->
get_value_
type
(),
float_scalar_type
);
}
// Check argument inverses
TEST
(
build_graph
,
arg_inverse
)
{
}
TEST
(
build_graph
,
arg_inverse
)
{}
test/op.cpp
View file @
973b3a0e
...
...
@@ -23,7 +23,7 @@ using namespace ngraph;
TEST
(
op
,
is_op
)
{
auto
arg0
=
op
::
parameter
(
element
::
Float
::
type
,
{
1
});
auto
arg0
=
op
::
parameter
(
element
::
Float
32
::
element_type
()
,
{
1
});
ASSERT_NE
(
nullptr
,
arg0
);
EXPECT_TRUE
(
arg0
->
is_parameter
());
EXPECT_FALSE
(
arg0
->
is_op
());
...
...
@@ -31,7 +31,7 @@ TEST(op, is_op)
TEST
(
op
,
is_parameter
)
{
auto
arg0
=
op
::
parameter
(
element
::
Float
::
type
,
{
1
});
auto
arg0
=
op
::
parameter
(
element
::
Float
32
::
element_type
()
,
{
1
});
ASSERT_NE
(
nullptr
,
arg0
);
auto
t0
=
op
::
add
(
arg0
,
arg0
);
ASSERT_NE
(
nullptr
,
t0
);
...
...
test/topological_sort.cpp
View file @
973b3a0e
...
...
@@ -29,21 +29,20 @@ using namespace ngraph;
static
bool
validate_list
(
const
vector
<
Node
*>&
nodes
)
{
bool
rc
=
true
;
for
(
auto
it
=
nodes
.
rbegin
();
it
!=
nodes
.
rend
();
it
++
)
for
(
auto
it
=
nodes
.
rbegin
();
it
!=
nodes
.
rend
();
it
++
)
{
Node
*
node
=
*
it
;
auto
node_tmp
=
*
it
;
auto
dependencies_tmp
=
node_tmp
->
arguments
();
auto
node_tmp
=
*
it
;
auto
dependencies_tmp
=
node_tmp
->
get_arguments
();
vector
<
Node
*>
dependencies
;
for
(
shared_ptr
<
Node
>
n
:
dependencies_tmp
)
{
dependencies
.
push_back
(
n
.
get
());
}
auto
tmp
=
it
+
1
;
for
(;
tmp
!=
nodes
.
rend
();
tmp
++
)
auto
tmp
=
it
+
1
;
for
(;
tmp
!=
nodes
.
rend
();
tmp
++
)
{
auto
dep_tmp
=
*
tmp
;
auto
found
=
find
(
dependencies
.
begin
(),
dependencies
.
end
(),
dep_tmp
);
auto
found
=
find
(
dependencies
.
begin
(),
dependencies
.
end
(),
dep_tmp
);
if
(
found
!=
dependencies
.
end
())
{
dependencies
.
erase
(
found
);
...
...
@@ -60,9 +59,9 @@ static bool validate_list(const vector<Node*>& nodes)
TEST
(
topological_sort
,
basic
)
{
vector
<
shared_ptr
<
Parameter
>>
args
;
for
(
int
i
=
0
;
i
<
10
;
i
++
)
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
auto
arg
=
op
::
parameter
(
element
::
Float
::
type
,
{
1
});
auto
arg
=
op
::
parameter
(
element
::
Float
32
::
element_type
()
,
{
1
});
ASSERT_NE
(
nullptr
,
arg
);
args
.
push_back
(
arg
);
}
...
...
@@ -79,13 +78,13 @@ TEST(topological_sort, basic)
auto
t4
=
op
::
add
(
t2
,
args
[
5
]);
ASSERT_NE
(
nullptr
,
t3
);
Node
::
ptr
r0
=
op
::
add
(
t3
,
t4
);
auto
r0
=
op
::
add
(
t3
,
t4
);
ASSERT_NE
(
nullptr
,
r0
);
auto
f0
=
op
::
function
(
r0
,
args
);
ASSERT_NE
(
nullptr
,
f0
);
ASSERT_EQ
(
2
,
r0
->
arguments
().
size
());
ASSERT_EQ
(
2
,
r0
->
get_
arguments
().
size
());
auto
op_r0
=
static_pointer_cast
<
Op
>
(
r0
);
Visualize
vz
;
...
...
test/util.cpp
View file @
973b3a0e
...
...
@@ -134,9 +134,7 @@ TEST(util, contains)
EXPECT_FALSE
(
contains
(
v1
,
8
));
}
TEST
(
util
,
remove_from
)
{
}
TEST
(
util
,
remove_from
)
{}
TEST
(
util
,
reduce
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment