Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
064fb0fc
Commit
064fb0fc
authored
Sep 01, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
formatting
parent
fc0455ba
Hide whitespace changes
Inline
Side-by-side
Showing
25 changed files
with
168 additions
and
127 deletions
+168
-127
common.hpp
src/ngraph/common.hpp
+4
-4
element_type.hpp
src/ngraph/element_type.hpp
+27
-21
function.hpp
src/ngraph/function.hpp
+6
-5
node.cpp
src/ngraph/node.cpp
+2
-1
node.hpp
src/ngraph/node.hpp
+6
-6
op.hpp
src/ngraph/op.hpp
+28
-14
broadcast.hpp
src/ngraph/ops/broadcast.hpp
+7
-5
constant.hpp
src/ngraph/ops/constant.hpp
+1
-1
convert.hpp
src/ngraph/ops/convert.hpp
+3
-3
dot.hpp
src/ngraph/ops/dot.hpp
+2
-1
parameter.hpp
src/ngraph/ops/parameter.hpp
+3
-2
shape.hpp
src/ngraph/shape.hpp
+2
-1
topological_sort.cpp
src/ngraph/topological_sort.cpp
+5
-6
topological_sort.hpp
src/ngraph/topological_sort.hpp
+4
-3
type.hpp
src/ngraph/type.hpp
+11
-4
visualize.cpp
src/ngraph/visualize.cpp
+6
-7
broadcast.cpp
src/ops/broadcast.cpp
+3
-3
convert.cpp
src/ops/convert.cpp
+2
-1
dot.cpp
src/ops/dot.cpp
+6
-3
function.cpp
src/ops/function.cpp
+3
-3
op.cpp
src/ops/op.cpp
+20
-10
parameter.cpp
src/ops/parameter.cpp
+1
-3
build_graph.cpp
test/build_graph.cpp
+8
-10
topological_sort.cpp
test/topological_sort.cpp
+7
-7
util.cpp
test/util.cpp
+1
-3
No files found.
src/ngraph/common.hpp
View file @
064fb0fc
...
@@ -15,8 +15,8 @@
...
@@ -15,8 +15,8 @@
#pragma once
#pragma once
#include <memory>
#include <memory>
#include <vector>
#include <set>
#include <set>
#include <vector>
// Names for types that aren't worth giving their own classes
// Names for types that aren't worth giving their own classes
namespace
ngraph
namespace
ngraph
...
@@ -25,7 +25,7 @@ namespace ngraph
...
@@ -25,7 +25,7 @@ namespace ngraph
class
Parameter
;
class
Parameter
;
class
ValueType
;
class
ValueType
;
template
<
typename
T
,
typename
...
A
>
template
<
typename
T
,
typename
...
A
>
std
::
shared_ptr
<
T
>
node
(
A
&&
...
args
)
std
::
shared_ptr
<
T
>
node
(
A
&&
...
args
)
{
{
return
std
::
make_shared
<
T
>
(
args
...);
return
std
::
make_shared
<
T
>
(
args
...);
...
@@ -33,13 +33,13 @@ namespace ngraph
...
@@ -33,13 +33,13 @@ namespace ngraph
/// Zero or more value types
/// Zero or more value types
using
ValueTypes
=
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
;
using
ValueTypes
=
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
;
/// Zero or more nodes
/// Zero or more nodes
using
Nodes
=
std
::
vector
<
std
::
shared_ptr
<
Node
>>
;
using
Nodes
=
std
::
vector
<
std
::
shared_ptr
<
Node
>>
;
/// A sequence of axes
/// A sequence of axes
using
AxisVector
=
std
::
vector
<
size_t
>
;
using
AxisVector
=
std
::
vector
<
size_t
>
;
/// A set of axes, for example, reduction axes
/// A set of axes, for example, reduction axes
using
AxisSet
=
std
::
set
<
size_t
>
;
using
AxisSet
=
std
::
set
<
size_t
>
;
...
...
src/ngraph/element_type.hpp
View file @
064fb0fc
...
@@ -43,27 +43,32 @@ namespace ngraph
...
@@ -43,27 +43,32 @@ namespace ngraph
bool
operator
==
(
const
Type
&
other
)
const
;
bool
operator
==
(
const
Type
&
other
)
const
;
bool
operator
!=
(
const
Type
&
other
)
const
{
return
!
(
*
this
==
other
);
}
bool
operator
!=
(
const
Type
&
other
)
const
{
return
!
(
*
this
==
other
);
}
private
:
private
:
static
std
::
map
<
std
::
string
,
Type
>
m_element_list
;
static
std
::
map
<
std
::
string
,
Type
>
m_element_list
;
size_t
m_bitwidth
;
size_t
m_bitwidth
;
bool
m_is_float
;
bool
m_is_float
;
bool
m_is_signed
;
bool
m_is_signed
;
const
std
::
string
m_cname
;
const
std
::
string
m_cname
;
};
};
// Provides a compile-time name for a C++ type.
// Provides a compile-time name for a C++ type.
// Used in TraitedType for the string that supplies the C++ type name during code generation,
// Used in TraitedType for the string that supplies the C++ type name during code generation,
// so it needs to be a valid C++ name.
// so it needs to be a valid C++ name.
template
<
typename
T
>
template
<
typename
T
>
const
char
*
traited_type_name
()
const
char
*
traited_type_name
()
{
{
throw
ngraph_error
(
"Unknown type"
);
throw
ngraph_error
(
"Unknown type"
);
}
}
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TTN( T ) \
#define NGRAPH_DEFINE_TTN(T) \
template<> constexpr const char* traited_type_name < T > () { return #T; }
template <> \
constexpr const char* traited_type_name<T>() \
{ \
return #T; \
}
// Literals (and probably other things we don't know about yet) need to have their C++ types
// Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation.
// access to both the instance and the C++ type used to hold the value during compilation.
...
@@ -72,10 +77,10 @@ namespace ngraph
...
@@ -72,10 +77,10 @@ namespace ngraph
{
{
protected
:
protected
:
TraitedType
()
TraitedType
()
:
Type
(
sizeof
(
T
)
*
8
,
:
Type
(
sizeof
(
T
)
*
8
,
std
::
is_floating_point
<
T
>::
value
,
std
::
is_floating_point
<
T
>::
value
,
std
::
is_signed
<
T
>::
value
,
std
::
is_signed
<
T
>::
value
,
traited_type_name
<
T
>
())
traited_type_name
<
T
>
())
{
{
}
}
...
@@ -83,31 +88,32 @@ namespace ngraph
...
@@ -83,31 +88,32 @@ namespace ngraph
// This is the C++ type used to hold a value of this element type during compilation
// This is the C++ type used to hold a value of this element type during compilation
using
type
=
T
;
using
type
=
T
;
// This returns a reference to an instance of this element type.
// This returns a reference to an instance of this element type.
static
const
TraitedType
<
T
>&
element_type
(){
static
const
TraitedType
<
T
>&
element_type
()
{
static
TraitedType
<
T
>
t
;
static
TraitedType
<
T
>
t
;
return
t
;
return
t
;
}
}
};
};
NGRAPH_DEFINE_TTN
(
float
)
NGRAPH_DEFINE_TTN
(
float
)
using
Float
=
TraitedType
<
float
>
;
using
Float
=
TraitedType
<
float
>
;
NGRAPH_DEFINE_TTN
(
int8_t
)
NGRAPH_DEFINE_TTN
(
int8_t
)
using
Int8
=
TraitedType
<
int8_t
>
;
using
Int8
=
TraitedType
<
int8_t
>
;
NGRAPH_DEFINE_TTN
(
int32_t
)
NGRAPH_DEFINE_TTN
(
int32_t
)
using
Int32
=
TraitedType
<
int32_t
>
;
using
Int32
=
TraitedType
<
int32_t
>
;
NGRAPH_DEFINE_TTN
(
int64_t
)
NGRAPH_DEFINE_TTN
(
int64_t
)
using
Int64
=
TraitedType
<
int64_t
>
;
using
Int64
=
TraitedType
<
int64_t
>
;
NGRAPH_DEFINE_TTN
(
uint8_t
)
NGRAPH_DEFINE_TTN
(
uint8_t
)
using
UInt8
=
TraitedType
<
uint8_t
>
;
using
UInt8
=
TraitedType
<
uint8_t
>
;
NGRAPH_DEFINE_TTN
(
uint32_t
)
NGRAPH_DEFINE_TTN
(
uint32_t
)
using
UInt32
=
TraitedType
<
uint32_t
>
;
using
UInt32
=
TraitedType
<
uint32_t
>
;
NGRAPH_DEFINE_TTN
(
uint64_t
)
NGRAPH_DEFINE_TTN
(
uint64_t
)
using
UInt64
=
TraitedType
<
uint64_t
>
;
using
UInt64
=
TraitedType
<
uint64_t
>
;
}
}
}
}
src/ngraph/function.hpp
View file @
064fb0fc
...
@@ -25,14 +25,15 @@ namespace ngraph
...
@@ -25,14 +25,15 @@ namespace ngraph
class
Function
class
Function
{
{
public
:
public
:
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
std
::
shared_ptr
<
Node
>
result
()
{
return
m_result
;
}
std
::
shared_ptr
<
Node
>
result
()
{
return
m_result
;
}
std
::
shared_ptr
<
Parameter
>
parameter
(
size_t
i
)
{
return
m_parameters
[
i
];
}
std
::
shared_ptr
<
Parameter
>
parameter
(
size_t
i
)
{
return
m_parameters
[
i
];
}
std
::
string
name
()
const
{
return
m_name
;
}
std
::
string
name
()
const
{
return
m_name
;
}
protected
:
protected
:
std
::
shared_ptr
<
Node
>
m_result
;
std
::
shared_ptr
<
Node
>
m_result
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>
m_parameters
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>
m_parameters
;
std
::
string
m_name
;
std
::
string
m_name
;
};
};
...
@@ -40,10 +41,10 @@ namespace ngraph
...
@@ -40,10 +41,10 @@ namespace ngraph
namespace
op
namespace
op
{
{
std
::
shared_ptr
<
Function
>
std
::
shared_ptr
<
Function
>
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
initializer_list
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
const
std
::
initializer_list
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
std
::
shared_ptr
<
Function
>
std
::
shared_ptr
<
Function
>
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
}
}
}
}
src/ngraph/node.cpp
View file @
064fb0fc
...
@@ -17,7 +17,8 @@
...
@@ -17,7 +17,8 @@
size_t
ngraph
::
Node
::
m_next_instance_id
=
0
;
size_t
ngraph
::
Node
::
m_next_instance_id
=
0
;
ngraph
::
Node
::
Node
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
,
std
::
shared_ptr
<
ValueType
>
type
)
ngraph
::
Node
::
Node
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
,
std
::
shared_ptr
<
ValueType
>
type
)
:
TypedValueMixin
(
type
)
:
TypedValueMixin
(
type
)
,
m_arguments
(
arguments
)
,
m_arguments
(
arguments
)
,
m_instance_id
(
m_next_instance_id
++
)
,
m_instance_id
(
m_next_instance_id
++
)
...
...
src/ngraph/node.hpp
View file @
064fb0fc
...
@@ -20,8 +20,8 @@
...
@@ -20,8 +20,8 @@
#include <iostream>
#include <iostream>
#include "type.hpp"
#include "common.hpp"
#include "common.hpp"
#include "type.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -32,11 +32,11 @@ namespace ngraph
...
@@ -32,11 +32,11 @@ namespace ngraph
/// view or a (possibly empty) tuple of values.
/// view or a (possibly empty) tuple of values.
class
Node
:
public
TypedValueMixin
,
public
std
::
enable_shared_from_this
<
Node
>
class
Node
:
public
TypedValueMixin
,
public
std
::
enable_shared_from_this
<
Node
>
{
{
protected
:
protected
:
Node
(
const
Nodes
&
arguments
,
std
::
shared_ptr
<
ValueType
>
type
=
nullptr
);
Node
(
const
Nodes
&
arguments
,
std
::
shared_ptr
<
ValueType
>
type
=
nullptr
);
virtual
~
Node
()
{}
virtual
~
Node
()
{}
public
:
public
:
/// A "one-liner" describing this node.
/// A "one-liner" describing this node.
virtual
std
::
string
description
()
const
=
0
;
virtual
std
::
string
description
()
const
=
0
;
...
@@ -68,10 +68,10 @@ namespace ngraph
...
@@ -68,10 +68,10 @@ namespace ngraph
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
protected
:
protected
:
Nodes
m_arguments
;
Nodes
m_arguments
;
std
::
multiset
<
Node
*>
m_users
;
std
::
multiset
<
Node
*>
m_users
;
std
::
string
m_name
;
std
::
string
m_name
;
size_t
m_instance_id
;
size_t
m_instance_id
;
static
size_t
m_next_instance_id
;
static
size_t
m_next_instance_id
;
};
};
}
}
src/ngraph/op.hpp
View file @
064fb0fc
...
@@ -24,39 +24,51 @@ namespace ngraph
...
@@ -24,39 +24,51 @@ namespace ngraph
{
{
namespace
op
namespace
op
{
{
std
::
shared_ptr
<
Node
>
abs
(
const
std
::
shared_ptr
<
Node
>&
arg
);
std
::
shared_ptr
<
Node
>
abs
(
const
std
::
shared_ptr
<
Node
>&
arg
);
std
::
shared_ptr
<
Node
>
add
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
add
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
std
::
shared_ptr
<
Node
>
ceiling
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
ceiling
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> convert();
//std::shared_ptr<Node> convert();
//std::shared_ptr<Node> convolution();
//std::shared_ptr<Node> convolution();
std
::
shared_ptr
<
Node
>
divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
std
::
shared_ptr
<
Node
>
equal
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
equal
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
exp
(
const
std
::
shared_ptr
<
Node
>&
arg0
);
std
::
shared_ptr
<
Node
>
exp
(
const
std
::
shared_ptr
<
Node
>&
arg0
);
std
::
shared_ptr
<
Node
>
floor
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
floor
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> get_tuple_element();
//std::shared_ptr<Node> get_tuple_element();
std
::
shared_ptr
<
Node
>
greater
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
greater
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std
::
shared_ptr
<
Node
>
less
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
less
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std
::
shared_ptr
<
Node
>
log
(
const
std
::
shared_ptr
<
Node
>&
arg0
);
std
::
shared_ptr
<
Node
>
log
(
const
std
::
shared_ptr
<
Node
>&
arg0
);
//std::shared_ptr<Node> logical(); and, or, not
//std::shared_ptr<Node> logical(); and, or, not
std
::
shared_ptr
<
Node
>
maximum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
maximum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
std
::
shared_ptr
<
Node
>
minimum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
multiply
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
minimum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
multiply
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
negative
(
const
std
::
shared_ptr
<
Node
>&
arg0
);
std
::
shared_ptr
<
Node
>
negative
(
const
std
::
shared_ptr
<
Node
>&
arg0
);
//std::shared_ptr<Node> pad();
//std::shared_ptr<Node> pad();
std
::
shared_ptr
<
Node
>
power
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
power
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> reduce();
//std::shared_ptr<Node> reduce();
// std::shared_ptr<Node> reduce_window();
// std::shared_ptr<Node> reduce_window();
std
::
shared_ptr
<
Node
>
remainder
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
remainder
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
reshape
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
Shape
&
shape
);
std
::
shared_ptr
<
Node
>
reshape
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
Shape
&
shape
);
//std::shared_ptr<Node> reverse();
//std::shared_ptr<Node> reverse();
//std::shared_ptr<Node> rng();
//std::shared_ptr<Node> rng();
//std::shared_ptr<Node> select();
//std::shared_ptr<Node> select();
//std::shared_ptr<Node> select_scatter();
//std::shared_ptr<Node> select_scatter();
//std::shared_ptr<Node> slice();
//std::shared_ptr<Node> slice();
std
::
shared_ptr
<
Node
>
subtract
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
subtract
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
//std::shared_ptr<Node> transpose();
//std::shared_ptr<Node> transpose();
//std::shared_ptr<Node> while();
//std::shared_ptr<Node> while();
}
}
...
@@ -81,6 +93,7 @@ namespace ngraph
...
@@ -81,6 +93,7 @@ namespace ngraph
class
FunctionOp
:
public
Op
class
FunctionOp
:
public
Op
{
{
virtual
std
::
string
description
()
const
override
{
return
"FunctionOp"
;
}
virtual
std
::
string
description
()
const
override
{
return
"FunctionOp"
;
}
protected
:
protected
:
std
::
shared_ptr
<
Node
>
m_function
;
std
::
shared_ptr
<
Node
>
m_function
;
};
};
...
@@ -95,6 +108,7 @@ namespace ngraph
...
@@ -95,6 +108,7 @@ namespace ngraph
// TODO: Implement for each op. This enables graphs to be built for now.
// TODO: Implement for each op. This enables graphs to be built for now.
virtual
void
propagate_types
()
override
{}
virtual
void
propagate_types
()
override
{}
protected
:
protected
:
BuiltinOp
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
args
)
BuiltinOp
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
args
)
:
Op
(
args
)
:
Op
(
args
)
...
...
src/ngraph/ops/broadcast.hpp
View file @
064fb0fc
...
@@ -25,7 +25,9 @@ namespace ngraph
...
@@ -25,7 +25,9 @@ namespace ngraph
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
/// the remaining axes in shape must be the same as the shape of arg.
///
///
BroadcastOp
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
)
BroadcastOp
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
)
:
BuiltinOp
({
arg
})
:
BuiltinOp
({
arg
})
,
m_shape
(
shape
)
,
m_shape
(
shape
)
,
m_broadcast_axes
(
broadcast_axes
)
,
m_broadcast_axes
(
broadcast_axes
)
...
@@ -36,14 +38,14 @@ namespace ngraph
...
@@ -36,14 +38,14 @@ namespace ngraph
virtual
void
propagate_types
()
override
;
virtual
void
propagate_types
()
override
;
protected
:
protected
:
Shape
m_shape
;
Shape
m_shape
;
AxisSet
m_broadcast_axes
;
AxisSet
m_broadcast_axes
;
};
};
namespace
op
namespace
op
{
{
std
::
shared_ptr
<
Node
>
broadcast
(
const
std
::
shared_ptr
<
Node
>&
tensor
,
std
::
shared_ptr
<
Node
>
broadcast
(
const
std
::
shared_ptr
<
Node
>&
tensor
,
const
Shape
&
shape
,
const
Shape
&
shape
,
AxisSet
&&
broadcast_axes
);
AxisSet
&&
broadcast_axes
);
}
}
}
}
src/ngraph/ops/constant.hpp
View file @
064fb0fc
...
@@ -56,7 +56,7 @@ namespace ngraph
...
@@ -56,7 +56,7 @@ namespace ngraph
ss
<<
description
()
<<
"_"
/* << node_id() */
;
ss
<<
description
()
<<
"_"
/* << node_id() */
;
return
ss
.
str
();
return
ss
.
str
();
}
}
typename
T
::
type
get_value
()
const
{
return
m_value
;
}
typename
T
::
type
get_value
()
const
{
return
m_value
;
}
protected
:
protected
:
...
...
src/ngraph/ops/convert.hpp
View file @
064fb0fc
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
namespace
ngraph
namespace
ngraph
{
{
class
ConvertOp
:
public
BuiltinOp
class
ConvertOp
:
public
BuiltinOp
{
{
public
:
public
:
...
@@ -28,13 +27,14 @@ namespace ngraph
...
@@ -28,13 +27,14 @@ namespace ngraph
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"convert"
;
}
virtual
std
::
string
get_op_class_name
()
const
override
{
return
"convert"
;
}
virtual
void
propagate_types
()
override
;
virtual
void
propagate_types
()
override
;
protected
:
protected
:
const
ngraph
::
element
::
Type
&
m_element_type
;
const
ngraph
::
element
::
Type
&
m_element_type
;
};
};
namespace
op
namespace
op
{
{
std
::
shared_ptr
<
ngraph
::
ConvertOp
>
convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
ngraph
::
element
::
Type
&
element_type
);
std
::
shared_ptr
<
ngraph
::
ConvertOp
>
convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
ngraph
::
element
::
Type
&
element_type
);
}
}
}
}
src/ngraph/ops/dot.hpp
View file @
064fb0fc
...
@@ -31,6 +31,7 @@ namespace ngraph
...
@@ -31,6 +31,7 @@ namespace ngraph
namespace
op
namespace
op
{
{
std
::
shared_ptr
<
Node
>
dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
}
}
}
}
src/ngraph/ops/parameter.hpp
View file @
064fb0fc
...
@@ -51,9 +51,10 @@ namespace ngraph
...
@@ -51,9 +51,10 @@ namespace ngraph
namespace
op
namespace
op
{
{
/// Factory for frameworks
/// Factory for frameworks
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
=
nullptr
);
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
=
nullptr
);
/// Convenience factory for tests
/// Convenience factory for tests
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
element
::
Type
element_type
,
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
element
::
Type
element_type
,
const
Shape
&
shape
);
const
Shape
&
shape
);
}
}
}
}
src/ngraph/shape.hpp
View file @
064fb0fc
...
@@ -36,9 +36,10 @@ namespace ngraph
...
@@ -36,9 +36,10 @@ namespace ngraph
}
}
/// Conversion to a vector of sizes.
/// Conversion to a vector of sizes.
operator
const
std
::
vector
<
size_t
>&
()
const
{
return
m_sizes
;
}
operator
const
std
::
vector
<
size_t
>&
()
const
{
return
m_sizes
;
}
bool
operator
==
(
const
Shape
&
shape
)
const
{
return
m_sizes
==
shape
.
m_sizes
;
}
bool
operator
==
(
const
Shape
&
shape
)
const
{
return
m_sizes
==
shape
.
m_sizes
;
}
bool
operator
!=
(
const
Shape
&
shape
)
const
{
return
m_sizes
!=
shape
.
m_sizes
;
}
bool
operator
!=
(
const
Shape
&
shape
)
const
{
return
m_sizes
!=
shape
.
m_sizes
;
}
protected
:
protected
:
std
::
vector
<
size_t
>
m_sizes
;
std
::
vector
<
size_t
>
m_sizes
;
};
};
...
...
src/ngraph/topological_sort.cpp
View file @
064fb0fc
...
@@ -12,8 +12,8 @@
...
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp"
#include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp"
#include "util.hpp"
using
namespace
ngraph
;
using
namespace
ngraph
;
...
@@ -21,16 +21,16 @@ using namespace std;
...
@@ -21,16 +21,16 @@ using namespace std;
void
ngraph
::
TopologicalSort
::
promote_node
(
Node
*
n
)
void
ngraph
::
TopologicalSort
::
promote_node
(
Node
*
n
)
{
{
for
(
auto
dn
=
m_dependent_nodes
.
begin
();
dn
!=
m_dependent_nodes
.
end
();
dn
++
)
for
(
auto
dn
=
m_dependent_nodes
.
begin
();
dn
!=
m_dependent_nodes
.
end
();
dn
++
)
{
{
if
(
dn
->
first
>
0
)
// Skip zero as they should never be promoted
if
(
dn
->
first
>
0
)
// Skip zero as they should never be promoted
{
{
auto
it
=
find
(
dn
->
second
.
begin
(),
dn
->
second
.
end
(),
n
);
auto
it
=
find
(
dn
->
second
.
begin
(),
dn
->
second
.
end
(),
n
);
if
(
it
!=
dn
->
second
.
end
())
if
(
it
!=
dn
->
second
.
end
())
{
{
// found the node
// found the node
dn
->
second
.
erase
(
it
);
dn
->
second
.
erase
(
it
);
m_dependent_nodes
[
dn
->
first
-
1
].
push_back
(
n
);
m_dependent_nodes
[
dn
->
first
-
1
].
push_back
(
n
);
}
}
}
}
}
}
...
@@ -38,8 +38,7 @@ void ngraph::TopologicalSort::promote_node(Node* n)
...
@@ -38,8 +38,7 @@ void ngraph::TopologicalSort::promote_node(Node* n)
void
ngraph
::
TopologicalSort
::
process
(
node_ptr
p
)
void
ngraph
::
TopologicalSort
::
process
(
node_ptr
p
)
{
{
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
{
{
list
<
Node
*>&
node_list
=
m_dependent_nodes
[
node
->
get_arguments
().
size
()];
list
<
Node
*>&
node_list
=
m_dependent_nodes
[
node
->
get_arguments
().
size
()];
node_list
.
push_back
(
node
.
get
());
node_list
.
push_back
(
node
.
get
());
});
});
...
...
src/ngraph/topological_sort.hpp
View file @
064fb0fc
...
@@ -14,9 +14,10 @@
...
@@ -14,9 +14,10 @@
#pragma once
#pragma once
#include <memory>
#include <map>
#include <list>
#include <list>
#include <map>
#include <memory>
#include <vector>
namespace
ngraph
namespace
ngraph
{
{
...
@@ -30,7 +31,7 @@ class ngraph::TopologicalSort
...
@@ -30,7 +31,7 @@ class ngraph::TopologicalSort
public
:
public
:
TopologicalSort
()
{}
TopologicalSort
()
{}
void
process
(
node_ptr
);
void
process
(
node_ptr
);
const
std
::
vector
<
Node
*>&
get_sorted_list
()
const
;
const
std
::
vector
<
Node
*>&
get_sorted_list
()
const
;
private
:
private
:
...
...
src/ngraph/type.hpp
View file @
064fb0fc
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
public
:
public
:
virtual
~
ValueType
()
{}
virtual
~
ValueType
()
{}
virtual
bool
operator
==
(
const
std
::
shared_ptr
<
ValueType
>&
that
)
const
=
0
;
virtual
bool
operator
==
(
const
std
::
shared_ptr
<
ValueType
>&
that
)
const
=
0
;
bool
operator
!=
(
const
std
::
shared_ptr
<
ValueType
>&
that
)
const
{
return
!
(
*
this
==
that
);
}
bool
operator
!=
(
const
std
::
shared_ptr
<
ValueType
>&
that
)
const
{
return
!
(
*
this
==
that
);
}
};
};
/// Describes a tensor view; an element type and a shape.
/// Describes a tensor view; an element type and a shape.
...
@@ -71,8 +71,11 @@ namespace ngraph
...
@@ -71,8 +71,11 @@ namespace ngraph
{
{
}
}
const
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
get_element_types
()
const
{
return
m_element_types
;
}
const
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
get_element_types
()
const
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
set_element_types
()
{
return
m_element_types
;
}
{
return
m_element_types
;
}
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
set_element_types
()
{
return
m_element_types
;
}
virtual
bool
operator
==
(
const
std
::
shared_ptr
<
ValueType
>&
that
)
const
override
;
virtual
bool
operator
==
(
const
std
::
shared_ptr
<
ValueType
>&
that
)
const
override
;
...
@@ -95,7 +98,10 @@ namespace ngraph
...
@@ -95,7 +98,10 @@ namespace ngraph
** Set the type
** Set the type
** /param type The new type
** /param type The new type
**/
**/
void
set_value_type
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
{
m_value_type
=
value_type
;
}
void
set_value_type
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
{
m_value_type
=
value_type
;
}
/**
/**
** Set the type to be a tensor view type
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param element_type The type of the tensor elements
...
@@ -114,6 +120,7 @@ namespace ngraph
...
@@ -114,6 +120,7 @@ namespace ngraph
** The type associated with this value.
** The type associated with this value.
**/
**/
const
std
::
shared_ptr
<
ValueType
>
get_value_type
()
const
{
return
m_value_type
;
}
const
std
::
shared_ptr
<
ValueType
>
get_value_type
()
const
{
return
m_value_type
;
}
protected
:
protected
:
std
::
shared_ptr
<
ValueType
>
m_value_type
;
std
::
shared_ptr
<
ValueType
>
m_value_type
;
};
};
...
...
src/ngraph/visualize.cpp
View file @
064fb0fc
...
@@ -12,13 +12,13 @@
...
@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <list>
#include <fstream>
#include <cstdio>
#include <cstdio>
#include <fstream>
#include <list>
#include "visualize.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
#include "util.hpp"
#include "util.hpp"
#include "visualize.hpp"
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
std
;
using
namespace
std
;
...
@@ -31,8 +31,7 @@ Visualize::Visualize(const string& name)
...
@@ -31,8 +31,7 @@ Visualize::Visualize(const string& name)
void
Visualize
::
add
(
node_ptr
p
)
void
Visualize
::
add
(
node_ptr
p
)
{
{
// map<size_t, list<node_ptr>> dependent_nodes;
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
{
{
for
(
auto
arg
:
node
->
get_arguments
())
for
(
auto
arg
:
node
->
get_arguments
())
{
{
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_node_id
()
<<
";
\n
"
;
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_node_id
()
<<
";
\n
"
;
...
@@ -42,7 +41,7 @@ void Visualize::add(node_ptr p)
...
@@ -42,7 +41,7 @@ void Visualize::add(node_ptr p)
void
Visualize
::
save_dot
(
const
string
&
path
)
const
void
Visualize
::
save_dot
(
const
string
&
path
)
const
{
{
auto
tmp_file
=
path
+
".tmp"
;
auto
tmp_file
=
path
+
".tmp"
;
ofstream
out
(
tmp_file
);
ofstream
out
(
tmp_file
);
if
(
out
)
if
(
out
)
{
{
...
@@ -53,7 +52,7 @@ void Visualize::save_dot(const string& path) const
...
@@ -53,7 +52,7 @@ void Visualize::save_dot(const string& path) const
stringstream
ss
;
stringstream
ss
;
ss
<<
"dot -Tpng "
<<
tmp_file
<<
" -o "
<<
path
;
ss
<<
"dot -Tpng "
<<
tmp_file
<<
" -o "
<<
path
;
auto
cmd
=
ss
.
str
();
auto
cmd
=
ss
.
str
();
auto
stream
=
popen
(
cmd
.
c_str
(),
"r"
);
auto
stream
=
popen
(
cmd
.
c_str
(),
"r"
);
pclose
(
stream
);
pclose
(
stream
);
...
...
src/ops/broadcast.cpp
View file @
064fb0fc
...
@@ -21,9 +21,9 @@ using namespace ngraph;
...
@@ -21,9 +21,9 @@ using namespace ngraph;
/// @param shape The shape of the result
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
/// the remaining axes in shape must be the same as the shape of arg.
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
broadcast
(
const
std
::
shared_ptr
<
Node
>&
tensor
,
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
broadcast
(
const
std
::
shared_ptr
<
Node
>&
tensor
,
const
Shape
&
shape
,
const
Shape
&
shape
,
AxisSet
&&
broadcast_axes
)
AxisSet
&&
broadcast_axes
)
{
{
return
make_shared
<
BroadcastOp
>
(
tensor
,
shape
,
broadcast_axes
);
return
make_shared
<
BroadcastOp
>
(
tensor
,
shape
,
broadcast_axes
);
}
}
...
...
src/ops/convert.cpp
View file @
064fb0fc
...
@@ -24,7 +24,8 @@ void ConvertOp::propagate_types()
...
@@ -24,7 +24,8 @@ void ConvertOp::propagate_types()
throw
ngraph_error
(
"NIY"
);
throw
ngraph_error
(
"NIY"
);
}
}
shared_ptr
<
ConvertOp
>
op
::
convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
element
::
Type
&
element_type
)
shared_ptr
<
ConvertOp
>
op
::
convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
element
::
Type
&
element_type
)
{
{
return
make_shared
<
ConvertOp
>
(
arg
,
element_type
);
return
make_shared
<
ConvertOp
>
(
arg
,
element_type
);
}
}
src/ops/dot.cpp
View file @
064fb0fc
...
@@ -20,15 +20,18 @@ using namespace std;
...
@@ -20,15 +20,18 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
DotOp
>
(
arg0
,
arg1
);
return
make_shared
<
DotOp
>
(
arg0
,
arg1
);
}
}
void
DotOp
::
propagate_types
()
void
DotOp
::
propagate_types
()
{
{
auto
arg0_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
get_value_type
());
auto
arg0_tensor_type
=
auto
arg1_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
1
)
->
get_value_type
());
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
get_value_type
());
auto
arg1_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
1
)
->
get_value_type
());
if
(
nullptr
==
arg0_tensor_type
||
nullptr
==
arg1_tensor_type
)
if
(
nullptr
==
arg0_tensor_type
||
nullptr
==
arg1_tensor_type
)
{
{
throw
ngraph_error
(
"Arguments to dot must be tensor views"
);
throw
ngraph_error
(
"Arguments to dot must be tensor views"
);
...
...
src/ops/function.cpp
View file @
064fb0fc
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>&
parameters
)
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>&
parameters
)
:
m_result
(
result
)
:
m_result
(
result
)
,
m_parameters
(
parameters
)
,
m_parameters
(
parameters
)
...
@@ -30,13 +30,13 @@ Function::Function(const std::shared_ptr<Node>&
...
@@ -30,13 +30,13 @@ Function::Function(const std::shared_ptr<Node>&
}
}
}
}
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
initializer_list
<
shared_ptr
<
Parameter
>>&
parameters
)
const
initializer_list
<
shared_ptr
<
Parameter
>>&
parameters
)
{
{
return
make_shared
<
Function
>
(
result
,
parameters
);
return
make_shared
<
Function
>
(
result
,
parameters
);
}
}
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
vector
<
shared_ptr
<
Parameter
>>&
parameters
)
const
vector
<
shared_ptr
<
Parameter
>>&
parameters
)
{
{
return
make_shared
<
Function
>
(
result
,
parameters
);
return
make_shared
<
Function
>
(
result
,
parameters
);
...
...
src/ops/op.cpp
View file @
064fb0fc
...
@@ -32,12 +32,14 @@ std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
...
@@ -32,12 +32,14 @@ std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
return
make_shared
<
AbsOp
>
(
arg
);
return
make_shared
<
AbsOp
>
(
arg
);
}
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
add
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
add
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
AddOp
>
(
arg0
,
arg1
);
return
make_shared
<
AddOp
>
(
arg0
,
arg1
);
}
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
ceiling
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
ceiling
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
CeilingOp
>
(
arg0
,
arg1
);
return
make_shared
<
CeilingOp
>
(
arg0
,
arg1
);
}
}
...
@@ -45,7 +47,8 @@ std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, con
...
@@ -45,7 +47,8 @@ std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, con
// 'convert',
// 'convert',
// 'convolution',
// 'convolution',
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
DivideOp
>
(
arg0
,
arg1
);
return
make_shared
<
DivideOp
>
(
arg0
,
arg1
);
}
}
...
@@ -55,7 +58,8 @@ std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
...
@@ -55,7 +58,8 @@ std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
return
make_shared
<
ExpOp
>
(
arg0
);
return
make_shared
<
ExpOp
>
(
arg0
);
}
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
floor
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
floor
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
FloorOp
>
(
arg0
,
arg1
);
return
make_shared
<
FloorOp
>
(
arg0
,
arg1
);
}
}
...
@@ -65,17 +69,20 @@ std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
...
@@ -65,17 +69,20 @@ std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
return
make_shared
<
LogOp
>
(
arg0
);
return
make_shared
<
LogOp
>
(
arg0
);
}
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
maximum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
maximum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
MaximumOp
>
(
arg0
,
arg1
);
return
make_shared
<
MaximumOp
>
(
arg0
,
arg1
);
}
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
minimum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
minimum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
MinimumOp
>
(
arg0
,
arg1
);
return
make_shared
<
MinimumOp
>
(
arg0
,
arg1
);
}
}
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
multiply
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
multiply
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
MultiplyOp
>
(
arg0
,
arg1
);
return
make_shared
<
MultiplyOp
>
(
arg0
,
arg1
);
}
}
...
@@ -87,14 +94,16 @@ std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
...
@@ -87,14 +94,16 @@ std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
// 'pad',
// 'pad',
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
power
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
power
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
PowerOp
>
(
arg0
,
arg1
);
return
make_shared
<
PowerOp
>
(
arg0
,
arg1
);
}
}
//'reduce',
//'reduce',
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
remainder
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
remainder
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
RemainderOp
>
(
arg0
,
arg1
);
return
make_shared
<
RemainderOp
>
(
arg0
,
arg1
);
}
}
...
@@ -109,7 +118,8 @@ std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, con
...
@@ -109,7 +118,8 @@ std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, con
// 'select',
// 'select',
//'slice',
//'slice',
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
subtract
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
std
::
shared_ptr
<
Node
>
ngraph
::
op
::
subtract
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
make_shared
<
SubtractOp
>
(
arg0
,
arg1
);
return
make_shared
<
SubtractOp
>
(
arg0
,
arg1
);
}
}
...
...
src/ops/parameter.cpp
View file @
064fb0fc
...
@@ -41,9 +41,7 @@ void Parameter::assign_function(Function* function, size_t index)
...
@@ -41,9 +41,7 @@ void Parameter::assign_function(Function* function, size_t index)
m_index
=
index
;
m_index
=
index
;
}
}
void
Parameter
::
propagate_types
()
void
Parameter
::
propagate_types
()
{}
{
}
shared_ptr
<
Parameter
>
ngraph
::
op
::
parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
shared_ptr
<
Parameter
>
ngraph
::
op
::
parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
{
{
...
...
test/build_graph.cpp
View file @
064fb0fc
...
@@ -28,7 +28,7 @@ TEST(build_graph, build_simple)
...
@@ -28,7 +28,7 @@ TEST(build_graph, build_simple)
auto
arg2
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
arg3
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
arg3
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
broadcast_1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
broadcast_1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
b1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
b1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
dot
=
node
<
DotOp
>
(
arg2
,
arg0
);
auto
dot
=
node
<
DotOp
>
(
arg2
,
arg0
);
ASSERT_EQ
(
dot
->
get_arguments
()[
0
],
arg2
);
ASSERT_EQ
(
dot
->
get_arguments
()[
0
],
arg2
);
ASSERT_EQ
(
dot
->
get_arguments
()[
1
],
arg0
);
ASSERT_EQ
(
dot
->
get_arguments
()[
1
],
arg0
);
...
@@ -50,7 +50,7 @@ TEST(build_graph, as_type)
...
@@ -50,7 +50,7 @@ TEST(build_graph, as_type)
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
auto
tp_vt
=
make_shared
<
TupleType
>
(
ValueTypes
{
tv_vt
,
tv_vt
});
auto
tp_vt
=
make_shared
<
TupleType
>
(
ValueTypes
{
tv_vt
,
tv_vt
});
auto
tp_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tp_vt
);
auto
tp_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tp_vt
);
ASSERT_EQ
(
nullptr
,
tp_tv
);
ASSERT_EQ
(
nullptr
,
tp_tv
);
auto
tp_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tp_vt
);
auto
tp_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tp_vt
);
ASSERT_EQ
(
tp_vt
,
tp_tp
);
ASSERT_EQ
(
tp_vt
,
tp_tp
);
...
@@ -78,8 +78,8 @@ TEST(build_graph, literal)
...
@@ -78,8 +78,8 @@ TEST(build_graph, literal)
{
{
// float scalar from a float
// float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0);
//auto float0 = FloatScalarConstant::make(3.0);
auto
float0
=
node
<
FloatScalarConstant
>
(
3.0
);
auto
float0
=
node
<
FloatScalarConstant
>
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
(),
Shape
{});
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
(),
Shape
{});
ASSERT_EQ
(
float0
->
get_value
(),
3.0
);
ASSERT_EQ
(
float0
->
get_value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
get_value_type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float0
->
get_value_type
(),
float_scalar_type
);
auto
d
=
node
<
DotOp
>
(
float0
,
float0
);
auto
d
=
node
<
DotOp
>
(
float0
,
float0
);
...
@@ -90,15 +90,13 @@ TEST(build_graph, literal)
...
@@ -90,15 +90,13 @@ TEST(build_graph, literal)
auto
float1
=
node
<
FloatScalarConstant
>
(
3
);
auto
float1
=
node
<
FloatScalarConstant
>
(
3
);
ASSERT_EQ
(
float1
->
get_value
(),
3
);
ASSERT_EQ
(
float1
->
get_value
(),
3
);
ASSERT_EQ
(
*
float1
->
get_value_type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float1
->
get_value_type
(),
float_scalar_type
);
auto
int32_0
=
node
<
Int32ScalarConstant
>
(
3.0
);
auto
int32_0
=
node
<
Int32ScalarConstant
>
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{});
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{});
ASSERT_EQ
(
int32_0
->
get_value
(),
3
);
ASSERT_EQ
(
int32_0
->
get_value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
get_value_type
(),
int32_scalar_type
);
ASSERT_EQ
(
*
int32_0
->
get_value_type
(),
int32_scalar_type
);
ASSERT_NE
(
*
int32_0
->
get_value_type
(),
float_scalar_type
);
ASSERT_NE
(
*
int32_0
->
get_value_type
(),
float_scalar_type
);
}
}
// Check argument inverses
// Check argument inverses
TEST
(
build_graph
,
arg_inverse
)
TEST
(
build_graph
,
arg_inverse
)
{}
{
}
test/topological_sort.cpp
View file @
064fb0fc
...
@@ -29,20 +29,20 @@ using namespace ngraph;
...
@@ -29,20 +29,20 @@ using namespace ngraph;
static
bool
validate_list
(
const
vector
<
Node
*>&
nodes
)
static
bool
validate_list
(
const
vector
<
Node
*>&
nodes
)
{
{
bool
rc
=
true
;
bool
rc
=
true
;
for
(
auto
it
=
nodes
.
rbegin
();
it
!=
nodes
.
rend
();
it
++
)
for
(
auto
it
=
nodes
.
rbegin
();
it
!=
nodes
.
rend
();
it
++
)
{
{
auto
node_tmp
=
*
it
;
auto
node_tmp
=
*
it
;
auto
dependencies_tmp
=
node_tmp
->
get_arguments
();
auto
dependencies_tmp
=
node_tmp
->
get_arguments
();
vector
<
Node
*>
dependencies
;
vector
<
Node
*>
dependencies
;
for
(
shared_ptr
<
Node
>
n
:
dependencies_tmp
)
for
(
shared_ptr
<
Node
>
n
:
dependencies_tmp
)
{
{
dependencies
.
push_back
(
n
.
get
());
dependencies
.
push_back
(
n
.
get
());
}
}
auto
tmp
=
it
+
1
;
auto
tmp
=
it
+
1
;
for
(;
tmp
!=
nodes
.
rend
();
tmp
++
)
for
(;
tmp
!=
nodes
.
rend
();
tmp
++
)
{
{
auto
dep_tmp
=
*
tmp
;
auto
dep_tmp
=
*
tmp
;
auto
found
=
find
(
dependencies
.
begin
(),
dependencies
.
end
(),
dep_tmp
);
auto
found
=
find
(
dependencies
.
begin
(),
dependencies
.
end
(),
dep_tmp
);
if
(
found
!=
dependencies
.
end
())
if
(
found
!=
dependencies
.
end
())
{
{
dependencies
.
erase
(
found
);
dependencies
.
erase
(
found
);
...
@@ -59,7 +59,7 @@ static bool validate_list(const vector<Node*>& nodes)
...
@@ -59,7 +59,7 @@ static bool validate_list(const vector<Node*>& nodes)
TEST
(
topological_sort
,
basic
)
TEST
(
topological_sort
,
basic
)
{
{
vector
<
shared_ptr
<
Parameter
>>
args
;
vector
<
shared_ptr
<
Parameter
>>
args
;
for
(
int
i
=
0
;
i
<
10
;
i
++
)
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
{
auto
arg
=
op
::
parameter
(
element
::
Float
::
element_type
(),
{
1
});
auto
arg
=
op
::
parameter
(
element
::
Float
::
element_type
(),
{
1
});
ASSERT_NE
(
nullptr
,
arg
);
ASSERT_NE
(
nullptr
,
arg
);
...
...
test/util.cpp
View file @
064fb0fc
...
@@ -134,9 +134,7 @@ TEST(util, contains)
...
@@ -134,9 +134,7 @@ TEST(util, contains)
EXPECT_FALSE
(
contains
(
v1
,
8
));
EXPECT_FALSE
(
contains
(
v1
,
8
));
}
}
TEST
(
util
,
remove_from
)
TEST
(
util
,
remove_from
)
{}
{
}
TEST
(
util
,
reduce
)
TEST
(
util
,
reduce
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment