Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
8c16125d
Commit
8c16125d
authored
Aug 17, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into cyphers/view
parents
0064cfd0
8d57ce68
Show whitespace changes
Inline
Side-by-side
Showing
32 changed files
with
774 additions
and
777 deletions
+774
-777
.clang-format
.clang-format
+7
-0
element_type.cpp
src/element_type.cpp
+4
-1
element_type.hpp
src/element_type.hpp
+1
-1
log.cpp
src/log.cpp
+3
-3
log.hpp
src/log.hpp
+1
-1
names.cpp
src/names.cpp
+0
-0
names.hpp
src/names.hpp
+30
-32
strides.cpp
src/strides.cpp
+1
-1
strides.hpp
src/strides.hpp
+1
-3
axes.cpp
src/transformers/axes.cpp
+2
-2
axes.hpp
src/transformers/axes.hpp
+362
-367
exop.cpp
src/transformers/exop.cpp
+2
-2
exop.hpp
src/transformers/exop.hpp
+142
-142
mock.hpp
src/transformers/mock.hpp
+118
-122
mock_transformer.hpp
src/transformers/mock_transformer.hpp
+9
-12
ndarray.hpp
src/transformers/ndarray.hpp
+1
-1
op_graph.cpp
src/transformers/op_graph.cpp
+4
-2
op_graph.hpp
src/transformers/op_graph.hpp
+0
-0
tree.hpp
src/tree.hpp
+2
-3
util.cpp
src/util.cpp
+1
-1
util.hpp
src/util.hpp
+47
-48
uuid.hpp
src/uuid.hpp
+2
-3
axes.cpp
test/axes.cpp
+10
-10
element_type.cpp
test/element_type.cpp
+2
-2
exop.cpp
test/exop.cpp
+2
-2
main.cpp
test/main.cpp
+1
-1
names.cpp
test/names.cpp
+5
-3
op_graph.cpp
test/op_graph.cpp
+2
-2
strides.cpp
test/strides.cpp
+2
-2
tensor.cpp
test/tensor.cpp
+3
-3
util.cpp
test/util.cpp
+5
-3
uuid.cpp
test/uuid.cpp
+2
-2
No files found.
.clang-format
View file @
8c16125d
...
@@ -44,3 +44,10 @@ SpacesInSquareBrackets: false
...
@@ -44,3 +44,10 @@ SpacesInSquareBrackets: false
SortIncludes: false
SortIncludes: false
ReflowComments: true
ReflowComments: true
IncludeCategories:
- Regex: '^".*'
Priority: 3
- Regex: '^<.*'
Priority: 2
SortIncludes: true
src/element_type.cpp
View file @
8c16125d
...
@@ -27,7 +27,10 @@ const ngraph::ElementType element_type_uint64_t = ngraph::ElementType(64, false,
...
@@ -27,7 +27,10 @@ const ngraph::ElementType element_type_uint64_t = ngraph::ElementType(64, false,
std
::
map
<
std
::
string
,
ngraph
::
ElementType
>
ngraph
::
ElementType
::
m_element_list
;
std
::
map
<
std
::
string
,
ngraph
::
ElementType
>
ngraph
::
ElementType
::
m_element_list
;
ngraph
::
ElementType
::
ElementType
(
size_t
bitwidth
,
bool
is_float
,
bool
is_signed
,
const
std
::
string
&
cname
)
ngraph
::
ElementType
::
ElementType
(
size_t
bitwidth
,
bool
is_float
,
bool
is_signed
,
const
std
::
string
&
cname
)
:
m_bitwidth
{
bitwidth
}
:
m_bitwidth
{
bitwidth
}
,
m_is_float
{
is_float
}
,
m_is_float
{
is_float
}
,
m_is_signed
{
is_signed
}
,
m_is_signed
{
is_signed
}
...
...
src/element_type.hpp
View file @
8c16125d
...
@@ -18,8 +18,8 @@
...
@@ -18,8 +18,8 @@
#pragma once
#pragma once
#include <string>
#include <map>
#include <map>
#include <string>
namespace
ngraph
namespace
ngraph
{
{
...
...
src/log.cpp
View file @
8c16125d
...
@@ -14,12 +14,12 @@
...
@@ -14,12 +14,12 @@
*/
*/
#include <chrono>
#include <chrono>
#include <condition_variable>
#include <ctime>
#include <iomanip>
#include <iomanip>
#include <iostream>
#include <iostream>
#include <ctime>
#include <thread>
#include <mutex>
#include <mutex>
#include <
condition_variable
>
#include <
thread
>
#include "log.hpp"
#include "log.hpp"
...
...
src/log.hpp
View file @
8c16125d
...
@@ -15,9 +15,9 @@
...
@@ -15,9 +15,9 @@
#pragma once
#pragma once
#include <deque>
#include <sstream>
#include <sstream>
#include <stdexcept>
#include <stdexcept>
#include <deque>
namespace
nervana
namespace
nervana
{
{
...
...
src/names.cpp
View file @
8c16125d
src/names.hpp
View file @
8c16125d
...
@@ -14,40 +14,39 @@
...
@@ -14,40 +14,39 @@
#pragma once
#pragma once
#include <string>
#include <map>
#include <map>
#include <string>
namespace
ngraph
namespace
ngraph
{
{
//================================================================================================
//================================================================================================
// NameableValue
// NameableValue
// An Axis labels a dimension of a tensor. The op-graph uses
// An Axis labels a dimension of a tensor. The op-graph uses
// the identity of Axis objects to pair and specify dimensions in
// the identity of Axis objects to pair and specify dimensions in
// symbolic expressions. This system has several advantages over
// symbolic expressions. This system has several advantages over
// using the length and position of the axis as in other frameworks:
// using the length and position of the axis as in other frameworks:
//
//
// 1) Convenience. The dimensions of tensors, which may be nested
// 1) Convenience. The dimensions of tensors, which may be nested
// deep in a computation graph, can be specified without having to
// deep in a computation graph, can be specified without having to
// calculate their lengths.
// calculate their lengths.
//
//
// 2) Safety. Axis labels are analogous to types in general-purpose
// 2) Safety. Axis labels are analogous to types in general-purpose
// programming languages, allowing objects to interact only when
// programming languages, allowing objects to interact only when
// they are permitted to do so in advance. In symbolic computation,
// they are permitted to do so in advance. In symbolic computation,
// this prevents interference between axes that happen to have the
// this prevents interference between axes that happen to have the
// same lengths but are logically distinct, e.g. if the number of
// same lengths but are logically distinct, e.g. if the number of
// training examples and the number of input features are both 50.
// training examples and the number of input features are both 50.
//
//
// TODO: Please add to the list...
// TODO: Please add to the list...
//
//
// Arguments:
// Arguments:
// length: The length of the axis.
// length: The length of the axis.
// batch: Whether the axis is a batch axis.
// batch: Whether the axis is a batch axis.
// recurrent: Whether the axis is a recurrent axis.
// recurrent: Whether the axis is a recurrent axis.
//================================================================================================
//================================================================================================
class
NameableValue
class
NameableValue
{
{
public
:
public
:
//!-----------------------------------------------------------------------------------
//!-----------------------------------------------------------------------------------
//! NameableValue
//! NameableValue
//! An object that can be named.
//! An object that can be named.
...
@@ -103,7 +102,6 @@ public:
...
@@ -103,7 +102,6 @@ public:
std
::
string
m_graph_label
;
std
::
string
m_graph_label
;
std
::
string
m_short_name
;
std
::
string
m_short_name
;
std
::
string
m_doc_string
;
std
::
string
m_doc_string
;
};
};
}
// end namespace ngraph
}
// end namespace ngraph
src/strides.cpp
View file @
8c16125d
#include <iostream>
#include <algorithm>
#include <algorithm>
#include <iostream>
#include "strides.hpp"
#include "strides.hpp"
#include "util.hpp"
#include "util.hpp"
...
...
src/strides.hpp
View file @
8c16125d
#pragma once
#pragma once
#include <cstdio>
#include <cstdio>
#include <vector>
#include <initializer_list>
#include <initializer_list>
#include <vector>
#include "element_type.hpp"
#include "element_type.hpp"
#include "tree.hpp"
#include "tree.hpp"
...
@@ -27,7 +27,6 @@ public:
...
@@ -27,7 +27,6 @@ public:
ElementType
et
=
element_type_float
);
ElementType
et
=
element_type_float
);
const
ElementType
&
get_type
()
const
{
return
m_element_type
;
}
const
ElementType
&
get_type
()
const
{
return
m_element_type
;
}
tensor_stride
full_strides
()
const
;
tensor_stride
full_strides
()
const
;
tensor_stride
strides
()
const
;
tensor_stride
strides
()
const
;
tensor_size
sizes
()
const
;
tensor_size
sizes
()
const
;
...
@@ -53,7 +52,6 @@ class ngraph::tensor_stride
...
@@ -53,7 +52,6 @@ class ngraph::tensor_stride
public
:
public
:
tensor_stride
();
tensor_stride
();
const
ElementType
&
get_type
()
const
{
return
m_element_type
;
}
const
ElementType
&
get_type
()
const
{
return
m_element_type
;
}
tensor_stride
full_strides
()
const
;
tensor_stride
full_strides
()
const
;
tensor_stride
strides
()
const
;
tensor_stride
strides
()
const
;
...
...
src/transformers/axes.cpp
View file @
8c16125d
...
@@ -12,10 +12,10 @@
...
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <cassert>
#include <cmath>
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include <cmath>
#include <cassert>
#include "axes.hpp"
#include "axes.hpp"
#include "util.hpp"
#include "util.hpp"
...
...
src/transformers/axes.hpp
View file @
8c16125d
...
@@ -14,130 +14,130 @@
...
@@ -14,130 +14,130 @@
#pragma once
#pragma once
#include <vector>
#include <string>
#include <memory>
#include <limits>
#include <initializer_list>
#include <initializer_list>
#include <limits>
#include <memory>
#include <set>
#include <set>
#include <string>
#include <vector>
#include "uuid.hpp"
#include "element_type.hpp"
#include "element_type.hpp"
#include "names.hpp"
#include "names.hpp"
#include "util.hpp"
#include "strides.hpp"
#include "strides.hpp"
#include "util.hpp"
#include "uuid.hpp"
#include "uuid.hpp"
#include "uuid.hpp"
namespace
ngraph
namespace
ngraph
{
{
class
Axes
;
class
Axes
;
class
Axis
;
class
Axis
;
class
FlattenedAxis
;
class
FlattenedAxis
;
class
TensorDescription
;
class
TensorDescription
;
class
Op
;
class
Op
;
using
op_ptr
=
std
::
shared_ptr
<
Op
>
;
using
op_ptr
=
std
::
shared_ptr
<
Op
>
;
using
tensor_description_ptr
=
std
::
shared_ptr
<
TensorDescription
>
;
using
tensor_description_ptr
=
std
::
shared_ptr
<
TensorDescription
>
;
using
axes_key_t
=
size_t
;
using
axes_key_t
=
size_t
;
class
slice
class
slice
{
{
public
:
public
:
slice
(
int64_t
start
=
-
1
,
int64_t
stop
=
-
1
,
int64_t
step
=
1
);
slice
(
int64_t
start
=
-
1
,
int64_t
stop
=
-
1
,
int64_t
step
=
1
);
size_t
sliced_length
(
size_t
length
)
const
;
size_t
sliced_length
(
size_t
length
)
const
;
private
:
private
:
size_t
m_start
;
size_t
m_start
;
size_t
m_stop
;
size_t
m_stop
;
int64_t
m_step
;
int64_t
m_step
;
};
};
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// default_dtype
// default_dtype
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def default_dtype(dtype=None):
// def default_dtype(dtype=None):
// if dtype is None:
// if dtype is None:
// dtype = np.dtype(np.float32)
// dtype = np.dtype(np.float32)
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// try:
// try:
// dtype = np.dtype(dtype)
// dtype = np.dtype(dtype)
// except TypeError:
// except TypeError:
// raise TypeError("Could not cast {} to np.dtype".format(dtype))
// raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype
// return dtype
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// default_int_dtype
// default_int_dtype
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def default_int_dtype(dtype=None):
// def default_int_dtype(dtype=None):
// if dtype is None:
// if dtype is None:
// dtype = np.dtype(np.int32)
// dtype = np.dtype(np.int32)
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// try:
// try:
// dtype = np.dtype(dtype)
// dtype = np.dtype(dtype)
// except TypeError:
// except TypeError:
// raise TypeError("Could not cast {} to np.dtype".format(dtype))
// raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype
// return dtype
//================================================================================================
//================================================================================================
// make_axis
// make_axis
// Returns a new Axis.
// Returns a new Axis.
//
//
// Args:
// Args:
// length (int, optional): Length of the axis.
// length (int, optional): Length of the axis.
// name (String, optional): Name of the axis.
// name (String, optional): Name of the axis.
// batch (bool, optional): This is a batch axis. Defaults to False.
// batch (bool, optional): This is a batch axis. Defaults to False.
// recurrent (bool, optional): This is a recurrent axis. Defaults to False.
// recurrent (bool, optional): This is a recurrent axis. Defaults to False.
// docstring (String, optional): A docstring for the axis.
// docstring (String, optional): A docstring for the axis.
//
//
// Returns:
// Returns:
// Axis: A new Axis.
// Axis: A new Axis.
//================================================================================================
//================================================================================================
Axis
make_axis
(
size_t
length
,
Axis
make_axis
(
size_t
length
,
const
std
::
string
&
name
=
""
,
const
std
::
string
&
name
=
""
,
bool
batch
=
false
,
bool
batch
=
false
,
bool
recurrent
=
false
);
bool
recurrent
=
false
);
//================================================================================================
//================================================================================================
// make_axes
// make_axes
// Makes an Axes object.
// Makes an Axes object.
//
//
// Args:
// Args:
// axes: A list of Axis.
// axes: A list of Axis.
//
//
// Returns:
// Returns:
// Axes: An Axes.
// Axes: An Axes.
//================================================================================================
//================================================================================================
Axes
make_axes
(
const
std
::
vector
<
Axis
>&
);
Axes
make_axes
(
const
std
::
vector
<
Axis
>&
);
//================================================================================================
//================================================================================================
// Axis
// Axis
// An Axis labels a dimension of a tensor. The op-graph uses
// An Axis labels a dimension of a tensor. The op-graph uses
// the identity of Axis objects to pair and specify dimensions in
// the identity of Axis objects to pair and specify dimensions in
// symbolic expressions. This system has several advantages over
// symbolic expressions. This system has several advantages over
// using the length and position of the axis as in other frameworks:
// using the length and position of the axis as in other frameworks:
//
//
// 1) Convenience. The dimensions of tensors, which may be nested
// 1) Convenience. The dimensions of tensors, which may be nested
// deep in a computation graph, can be specified without having to
// deep in a computation graph, can be specified without having to
// calculate their lengths.
// calculate their lengths.
//
//
// 2) Safety. Axis labels are analogous to types in general-purpose
// 2) Safety. Axis labels are analogous to types in general-purpose
// programming languages, allowing objects to interact only when
// programming languages, allowing objects to interact only when
// they are permitted to do so in advance. In symbolic computation,
// they are permitted to do so in advance. In symbolic computation,
// this prevents interference between axes that happen to have the
// this prevents interference between axes that happen to have the
// same lengths but are logically distinct, e.g. if the number of
// same lengths but are logically distinct, e.g. if the number of
// training examples and the number of input features are both 50.
// training examples and the number of input features are both 50.
//
//
// TODO: Please add to the list...
// TODO: Please add to the list...
//
//
// Arguments:
// Arguments:
// length: The length of the axis.
// length: The length of the axis.
// batch: Whether the axis is a batch axis.
// batch: Whether the axis is a batch axis.
// recurrent: Whether the axis is a recurrent axis.
// recurrent: Whether the axis is a recurrent axis.
//================================================================================================
//================================================================================================
class
Axis
class
Axis
{
{
public
:
public
:
Axis
&
operator
+
(
const
Axis
&
rhs
);
Axis
&
operator
+
(
const
Axis
&
rhs
);
Axis
&
operator
-
(
const
Axis
&
rhs
);
Axis
&
operator
-
(
const
Axis
&
rhs
);
...
@@ -145,7 +145,6 @@ public:
...
@@ -145,7 +145,6 @@ public:
Axis
(
size_t
length
,
const
std
::
string
&
new_name
);
Axis
(
size_t
length
,
const
std
::
string
&
new_name
);
virtual
~
Axis
()
{}
virtual
~
Axis
()
{}
void
named
(
const
std
::
string
&
new_name
);
void
named
(
const
std
::
string
&
new_name
);
//!-----------------------------------------------------------------------------------
//!-----------------------------------------------------------------------------------
...
@@ -232,99 +231,99 @@ public:
...
@@ -232,99 +231,99 @@ public:
uuid_type
uuid
;
uuid_type
uuid
;
size_t
__length
;
size_t
__length
;
static
size_t
__name_counter
;
static
size_t
__name_counter
;
};
};
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// _sliced_length
// _sliced_length
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def _sliced_length(s, incoming_length):
// def _sliced_length(s, incoming_length):
// start, stop, step = s.indices(incoming_length)
// start, stop, step = s.indices(incoming_length)
// # max with 0 so we dont ever return a negative length. This
// # max with 0 so we dont ever return a negative length. This
// # matches how python handles it internally. Raising an exception
// # matches how python handles it internally. Raising an exception
// # might also be reasonable.
// # might also be reasonable.
// if step == 1:
// if step == 1:
// return max(stop - start, 0)
// return max(stop - start, 0)
// elif step == -1:
// elif step == -1:
// return max(start - stop, 0)
// return max(start - stop, 0)
// else:
// else:
// _validate_slice(s)
// _validate_slice(s)
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// _validate_slice
// _validate_slice
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def _validate_slice(s):
// def _validate_slice(s):
// if s.step not in (-1, 1, None):
// if s.step not in (-1, 1, None):
// raise ValueError((
// raise ValueError((
// 'SlicedAxis cant currently handle a step size other '
// 'SlicedAxis cant currently handle a step size other '
// 'than -1, 1 or None. Was given {step} in slice {slice}'
// 'than -1, 1 or None. Was given {step} in slice {slice}'
// ).format(
// ).format(
// step=s.step,
// step=s.step,
// slice=s,
// slice=s,
// ))
// ))
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// slice_axis
// slice_axis
// Slice an axis, return complete new axis
// Slice an axis, return complete new axis
// TODO: deprecate this after the axis refactoring
// TODO: deprecate this after the axis refactoring
//
//
// Arguments:
// Arguments:
// axis: the axis to be sliced
// axis: the axis to be sliced
// s: slice
// s: slice
//
//
// Returns:
// Returns:
// Axis instance, the new sliced axis
// Axis instance, the new sliced axis
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def slice_axis(axis, s):
// def slice_axis(axis, s):
Axis
slice_axis
(
const
Axis
&
axis
,
const
slice
&
s
);
Axis
slice_axis
(
const
Axis
&
axis
,
const
slice
&
s
);
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// duplicates
// duplicates
// Returns a list of Axis objects which have duplicate names in arr
// Returns a list of Axis objects which have duplicate names in arr
//
//
// Arguments:
// Arguments:
// arr: The iterable of Axis objects to check for duplicates in.
// arr: The iterable of Axis objects to check for duplicates in.
//
//
// Returns:
// Returns:
// list of Axis: duplicate Axis found in arr
// list of Axis: duplicate Axis found in arr
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
std
::
vector
<
std
::
string
>
duplicates
(
const
std
::
vector
<
Axis
>&
ax
);
std
::
vector
<
std
::
string
>
duplicates
(
const
std
::
vector
<
Axis
>&
ax
);
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// with_args_as_axes
// with_args_as_axes
// A decorator to cast arguments to axes.
// A decorator to cast arguments to axes.
//
//
// Arguments:
// Arguments:
// f: The function to be decorated.
// f: The function to be decorated.
//
//
// Returns:
// Returns:
// The decorated function.
// The decorated function.
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def with_args_as_axes(f):
// def with_args_as_axes(f):
// @wraps(f)
// @wraps(f)
// def wrapper(*args):
// def wrapper(*args):
// """
// """
// The decorated function. Performs the conversion
// The decorated function. Performs the conversion
// to Axes.
// to Axes.
// Arguments:
// Arguments:
// *args: Arguments intended for the original function.
// *args: Arguments intended for the original function.
// Returns:
// Returns:
// Return value of the original function.
// Return value of the original function.
// """
// """
// args = [Axes(arg) for arg in args]
// args = [Axes(arg) for arg in args]
// return f(*args)
// return f(*args)
// return wrapper
// return wrapper
//================================================================================================
//================================================================================================
// Axes
// Axes
// An Axes is a tuple of Axis objects used as a label for a tensor's
// An Axes is a tuple of Axis objects used as a label for a tensor's
// dimensions.
// dimensions.
//================================================================================================
//================================================================================================
class
Axes
class
Axes
{
{
public
:
public
:
std
::
vector
<
Axis
>
axes
;
std
::
vector
<
Axis
>
axes
;
uuid_type
uuid
;
uuid_type
uuid
;
...
@@ -706,47 +705,47 @@ public:
...
@@ -706,47 +705,47 @@ public:
std
::
vector
<
Axis
>
convert
(
const
Axes
&
ax
);
std
::
vector
<
Axis
>
convert
(
const
Axes
&
ax
);
std
::
vector
<
Axis
>
convert
(
const
std
::
vector
<
Axes
>&
ax
);
std
::
vector
<
Axis
>
convert
(
const
std
::
vector
<
Axes
>&
ax
);
private
:
private
:
void
check_duplicates
();
void
check_duplicates
();
};
};
//================================================================================================
// DuplicateAxisNames
//================================================================================================
// class DuplicateAxisNames(ValueError):
// def __init__(self, message, duplicate_axis_names):
// super(DuplicateAxisNames, self).__init__(message)
// self.duplicate_axis_names = duplicate_axis_names
//================================================================================================
// IncompatibleAxesError
//================================================================================================
// class IncompatibleAxesError(ValueError):
// pass
//================================================================================================
// UnmatchedAxesError
//================================================================================================
// class UnmatchedAxesError(IncompatibleAxesError):
// pass
//================================================================================================
//================================================================================================
// AxesMap
// DuplicateAxisNames
// AxesMap provides a way to define a axis name mapping: {Axis.name: Axis.name} and
//================================================================================================
// then apply this mapping to an Axes and get new Axes out.
//
// class DuplicateAxisNames(ValueError):
// Right now AxesMap is implemented as immutible because I didn't want to deal with
// def __init__(self, message, duplicate_axis_names):
// enforcing _assert_valid_axes_map on every method which mutates a dict and I didn't
// super(DuplicateAxisNames, self).__init__(message)
// need a mutable datastructure anyway. Feel free to make it mutable and add in
// invariant enforcement.
// self.duplicate_axis_names = duplicate_axis_names
//================================================================================================
class
AxesMap
:
public
std
::
map
<
std
::
string
,
std
::
string
>
//================================================================================================
{
// IncompatibleAxesError
public
:
//================================================================================================
// class IncompatibleAxesError(ValueError):
// pass
//================================================================================================
// UnmatchedAxesError
//================================================================================================
// class UnmatchedAxesError(IncompatibleAxesError):
// pass
//================================================================================================
// AxesMap
// AxesMap provides a way to define a axis name mapping: {Axis.name: Axis.name} and
// then apply this mapping to an Axes and get new Axes out.
//
// Right now AxesMap is implemented as immutible because I didn't want to deal with
// enforcing _assert_valid_axes_map on every method which mutates a dict and I didn't
// need a mutable datastructure anyway. Feel free to make it mutable and add in
// invariant enforcement.
//================================================================================================
class
AxesMap
:
public
std
::
map
<
std
::
string
,
std
::
string
>
{
public
:
AxesMap
(
const
std
::
pair
<
std
::
string
,
std
::
string
>&
);
AxesMap
(
const
std
::
pair
<
std
::
string
,
std
::
string
>&
);
AxesMap
(
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
std
::
string
>>
);
AxesMap
(
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
std
::
string
>>
);
...
@@ -762,74 +761,70 @@ public:
...
@@ -762,74 +761,70 @@ public:
//--------------------------------------------------------------------------------------------
//--------------------------------------------------------------------------------------------
Axis
map_axis
(
const
Axis
&
old_axis
)
const
;
Axis
map_axis
(
const
Axis
&
old_axis
)
const
;
private
:
private
:
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
duplicate_axis_names
();
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
duplicate_axis_names
();
void
assert_valid_axes_map
();
void
assert_valid_axes_map
();
public
:
public
:
// def invert(self):
// def invert(self):
// return {v: k for k, v in self.items()}
// return {v: k for k, v in self.items()}
};
};
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// _reduce_nested
// _reduce_nested
// Reduces a nested sequence by applying a function to each
// Reduces a nested sequence by applying a function to each
// of its elements and returns an aggregation.
// of its elements and returns an aggregation.
//
//
// Arguments:
// Arguments:
// elem: The object to be reduced, either a sequence
// elem: The object to be reduced, either a sequence
// or a singleton.
// or a singleton.
// agg: A variable holding information collected
// agg: A variable holding information collected
// as the sequence is collapsed.
// as the sequence is collapsed.
// func: A function to augment the aggregate by processing
// func: A function to augment the aggregate by processing
// a singleton. Should have the form func(agg, elem) -> agg
// a singleton. Should have the form func(agg, elem) -> agg
//
//
// Returns:
// Returns:
// agg: The final aggregate returned by the function.
// agg: The final aggregate returned by the function.
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def _reduce_nested(elem, agg, func):
// def _reduce_nested(elem, agg, func):
// if isinstance(elem, collections.Iterable):
// if isinstance(elem, collections.Iterable):
// for sub in elem:
// for sub in elem:
// agg = _reduce_nested(sub, agg, func)
// agg = _reduce_nested(sub, agg, func)
// return agg
// return agg
// else:
// else:
// return func(agg, elem)
// return func(agg, elem)
//================================================================================================
//================================================================================================
// FlattenedAxis
// FlattenedAxis
// A FlattenedAxis has length which is the product of the lengths of all
// A FlattenedAxis has length which is the product of the lengths of all
// Axis in the axes. The original Axes object is stored so that we can later
// Axis in the axes. The original Axes object is stored so that we can later
// unflatten this Axis back to its original component Axis.
// unflatten this Axis back to its original component Axis.
//
//
// Notes: since we allows Axis to have duplicated names globally, NameableValue
// Notes: since we allows Axis to have duplicated names globally, NameableValue
// is not used here.
// is not used here.
//================================================================================================
//================================================================================================
class
FlattenedAxis
:
public
Axis
class
FlattenedAxis
:
public
Axis
{
{
public
:
public
:
FlattenedAxis
(
const
std
::
vector
<
Axis
>&
list
,
const
std
::
string
&
new_name
=
""
);
FlattenedAxis
(
const
std
::
vector
<
Axis
>&
list
,
const
std
::
string
&
new_name
=
""
);
virtual
~
FlattenedAxis
()
{}
virtual
~
FlattenedAxis
()
{}
//--------------------------------------------------------------------------------------------
//--------------------------------------------------------------------------------------------
// Returns:
// Returns:
// True is this is a FlattendAxis.
// True is this is a FlattendAxis.
//--------------------------------------------------------------------------------------------
//--------------------------------------------------------------------------------------------
bool
is_flattened
()
const
{
return
true
;
}
bool
is_flattened
()
const
{
return
true
;
}
//--------------------------------------------------------------------------------------------
//--------------------------------------------------------------------------------------------
// Returns:
// Returns:
// Whether this axes contains no collapsed axes.
// Whether this axes contains no collapsed axes.
//--------------------------------------------------------------------------------------------
//--------------------------------------------------------------------------------------------
bool
empty
()
const
{
return
axes
.
size
()
==
0
;
}
bool
empty
()
const
{
return
axes
.
size
()
==
0
;
}
//--------------------------------------------------------------------------------------------
//--------------------------------------------------------------------------------------------
// Returns:
// Returns:
// Whether this axes contains exactly one collapsed axes.
// Whether this axes contains exactly one collapsed axes.
//--------------------------------------------------------------------------------------------
//--------------------------------------------------------------------------------------------
bool
single
()
const
{
return
axes
.
size
()
==
0
;
}
bool
single
()
const
{
return
axes
.
size
()
==
0
;
}
bool
operator
==
(
const
Axis
&
other
)
const
;
bool
operator
==
(
const
Axis
&
other
)
const
;
// def __hash__(self):
// def __hash__(self):
...
@@ -841,96 +836,96 @@ public:
...
@@ -841,96 +836,96 @@ public:
// return 'FlattenedAxis(%s)' % ', '.join(repr(axis) for axis in self.axes)
// return 'FlattenedAxis(%s)' % ', '.join(repr(axis) for axis in self.axes)
std
::
vector
<
Axis
>
axes
;
std
::
vector
<
Axis
>
axes
;
};
};
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// default_dtype
// default_dtype
// Reduces a nested tuple describing the strides of a tensor
// Reduces a nested tuple describing the strides of a tensor
// into a tuple giving the stride of each of its dimensions.
// into a tuple giving the stride of each of its dimensions.
//
//
// Arguments:
// Arguments:
// strides: The nested tuple.
// strides: The nested tuple.
//
//
// Returns:
// Returns:
// strides: The tuple of strides.
// strides: The tuple of strides.
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def reduce_strides(strides):
// def reduce_strides(strides):
// return tuple(int(_reduce_nested(elem, float('inf'), min))
// return tuple(int(_reduce_nested(elem, float('inf'), min))
// for elem in strides)
// for elem in strides)
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// _make_stride
// _make_stride
// Generates a nested tuple that provides the striding information
// Generates a nested tuple that provides the striding information
// for an occurrence of axis. If the axis is a FlattenedAxis, the
// for an occurrence of axis. If the axis is a FlattenedAxis, the
// stride will be a tuple containing the strides of each collapsed
// stride will be a tuple containing the strides of each collapsed
// axis. Otherwise, the stride will be an integer.
// axis. Otherwise, the stride will be an integer.
//
//
// Arguments:
// Arguments:
// inner_size: The total size of all dimensions smaller than this
// inner_size: The total size of all dimensions smaller than this
// axis, i.e. all axes to the right of this one when they are
// axis, i.e. all axes to the right of this one when they are
// laid out in c-contiguous order.
// laid out in c-contiguous order.
// axis: The axis for which we are generating a stride.
// axis: The axis for which we are generating a stride.
// fsz: A nested tuple supplying the sizes of each dimension collapsed
// fsz: A nested tuple supplying the sizes of each dimension collapsed
// into the axis. The size may be larger than the length of the axis.
// into the axis. The size may be larger than the length of the axis.
//
//
// Returns:
// Returns:
// inner_size: The total size of this axis and all smaller dimensions.
// inner_size: The total size of this axis and all smaller dimensions.
// stride: The stride given to the axis.
// stride: The stride given to the axis.
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def _make_stride(inner_size, axis, fsz):
// def _make_stride(inner_size, axis, fsz):
// if axis.is_flattened:
// if axis.is_flattened:
// return _make_strides(inner_size, axis.axes, fsz)
// return _make_strides(inner_size, axis.axes, fsz)
// else:
// else:
// stride = inner_size
// stride = inner_size
// inner_size *= fsz
// inner_size *= fsz
// return inner_size, stride
// return inner_size, stride
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// _make_strides
// _make_strides
// Generates a tuple of strides for a set of axes. See _make_stride
// Generates a tuple of strides for a set of axes. See _make_stride
// for a description of the stride given to each axis.
// for a description of the stride given to each axis.
//
//
// Arguments:
// Arguments:
// inner_size: The total size of all dimensions smaller than
// inner_size: The total size of all dimensions smaller than
// the axes.
// the axes.
// axes: The axes for which we are generating strides.
// axes: The axes for which we are generating strides.
// full_sizes: The size of each axis.
// full_sizes: The size of each axis.
//
//
// Returns:
// Returns:
// inner_size: The total size of these axes and all smaller dimensions.
// inner_size: The total size of these axes and all smaller dimensions.
// strides: The strides generated for the axes.
// strides: The strides generated for the axes.
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
// def _make_strides(inner_size, axes, full_sizes):
// def _make_strides(inner_size, axes, full_sizes):
// full_strides = []
// full_strides = []
// for axis, fsz in reversed(list(zip(axes, full_sizes))):
// for axis, fsz in reversed(list(zip(axes, full_sizes))):
// inner_size, stride = _make_stride(inner_size, axis, fsz)
// inner_size, stride = _make_stride(inner_size, axis, fsz)
// full_strides.append(stride)
// full_strides.append(stride)
// return inner_size, tuple(reversed(full_strides))
// return inner_size, tuple(reversed(full_strides))
//================================================================================================
//================================================================================================
// TensorDescription
// TensorDescription
// Description of a tensor that will be allocated in hardware.
// Description of a tensor that will be allocated in hardware.
//
//
// Names the tensor's dimensions with axes and holds pointers to the
// Names the tensor's dimensions with axes and holds pointers to the
// buffer allocated by the analysis and the backend tensor value
// buffer allocated by the analysis and the backend tensor value
// (e.g. a cpu or gpu tensor).
// (e.g. a cpu or gpu tensor).
//
//
// Arguments:
// Arguments:
// axes: Axes of the tensor.
// axes: Axes of the tensor.
// base: If a view, the viewed tensor's description.
// base: If a view, the viewed tensor's description.
// dtype: The type of the tensor.
// dtype: The type of the tensor.
// full_strides: The strides of each axis.
// full_strides: The strides of each axis.
// full_sizes: The allocated size of each axis (may be larger than the axis).
// full_sizes: The allocated size of each axis (may be larger than the axis).
// offset: An offset into the viewed tensor.
// offset: An offset into the viewed tensor.
// next_tensor_decription: In a reshape, tensor description of reshaped tensor.
// next_tensor_decription: In a reshape, tensor description of reshaped tensor.
// is_persistent: The tensor should be persistent, i.e. survive from computation to
// is_persistent: The tensor should be persistent, i.e. survive from computation to
// computation.
// computation.
// is_input: The device tensor can be written from the host.
// is_input: The device tensor can be written from the host.
// **kwargs: Additional args for related classes.
// **kwargs: Additional args for related classes.
//================================================================================================
//================================================================================================
class
TensorDescription
:
public
NameableValue
class
TensorDescription
:
public
NameableValue
{
{
public
:
public
:
//!-----------------------------------------------------------------------------------
//!-----------------------------------------------------------------------------------
//! constructor
//! constructor
//!-----------------------------------------------------------------------------------
//!-----------------------------------------------------------------------------------
...
@@ -1487,7 +1482,7 @@ public:
...
@@ -1487,7 +1482,7 @@ public:
ngraph
::
tensor_size
full_sizes
;
ngraph
::
tensor_size
full_sizes
;
ngraph
::
tensor_stride
full_strides
;
ngraph
::
tensor_stride
full_strides
;
tensor_description_ptr
next_tensor_description
;
tensor_description_ptr
next_tensor_description
;
};
};
}
// end of namespace ngraph
}
// end of namespace ngraph
...
...
src/transformers/exop.cpp
View file @
8c16125d
...
@@ -12,10 +12,10 @@
...
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <cmath>
#include <exception>
#include <memory>
#include <memory>
#include <sstream>
#include <sstream>
#include <exception>
#include <cmath>
#include "exop.hpp"
#include "exop.hpp"
#include "op_graph.hpp"
#include "op_graph.hpp"
...
...
src/transformers/exop.hpp
View file @
8c16125d
...
@@ -15,67 +15,66 @@
...
@@ -15,67 +15,66 @@
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include <
string
>
#include <
list
>
#include <map>
#include <map>
#include <memory>
#include <memory>
#include <vector>
#include <sstream>
#include <set>
#include <set>
#include <list>
#include <sstream>
#include <string>
#include <vector>
#include "axes.hpp"
#include "mock.hpp"
#include "mock.hpp"
#include "op_graph.hpp"
#include "op_graph.hpp"
#include "axes.hpp"
namespace
ngraph
namespace
ngraph
{
{
// forward declaration. This will hopefully go away
// forward declaration. This will hopefully go away
class
ExecutionGraph
;
class
ExecutionGraph
;
class
TensorDescription
;
class
TensorDescription
;
class
InputDecl
;
class
InputDecl
;
class
OutputDecl
;
class
OutputDecl
;
class
TensorDecl
;
class
TensorDecl
;
class
TensorViewDecl
;
class
TensorViewDecl
;
class
ExOp
;
class
ExOp
;
class
Op
;
class
Op
;
class
ComputationDecl
;
class
ComputationDecl
;
class
ExOpBlock
;
class
ExOpBlock
;
class
ExecutionState
;
class
ExecutionState
;
using
output_decl_ptr
=
std
::
shared_ptr
<
OutputDecl
>
;
using
output_decl_ptr
=
std
::
shared_ptr
<
OutputDecl
>
;
using
input_decl_ptr
=
std
::
shared_ptr
<
InputDecl
>
;
using
input_decl_ptr
=
std
::
shared_ptr
<
InputDecl
>
;
using
tensor_decl_ptr
=
std
::
shared_ptr
<
TensorDecl
>
;
using
tensor_decl_ptr
=
std
::
shared_ptr
<
TensorDecl
>
;
using
tensor_view_decl_ptr
=
std
::
shared_ptr
<
TensorViewDecl
>
;
using
tensor_view_decl_ptr
=
std
::
shared_ptr
<
TensorViewDecl
>
;
using
exop_ptr
=
std
::
shared_ptr
<
ExOp
>
;
using
exop_ptr
=
std
::
shared_ptr
<
ExOp
>
;
using
computation_decl_ptr
=
std
::
shared_ptr
<
ComputationDecl
>
;
using
computation_decl_ptr
=
std
::
shared_ptr
<
ComputationDecl
>
;
using
execution_graph_ptr
=
std
::
shared_ptr
<
ExecutionGraph
>
;
using
execution_graph_ptr
=
std
::
shared_ptr
<
ExecutionGraph
>
;
using
exop_block_ptr
=
std
::
shared_ptr
<
ExOpBlock
>
;
using
exop_block_ptr
=
std
::
shared_ptr
<
ExOpBlock
>
;
using
tensor_ptr
=
std
::
shared_ptr
<
TensorInterface
>
;
using
tensor_ptr
=
std
::
shared_ptr
<
TensorInterface
>
;
using
transformer_ptr
=
std
::
shared_ptr
<
Transformer
>
;
using
transformer_ptr
=
std
::
shared_ptr
<
Transformer
>
;
using
execution_state_ptr
=
std
::
shared_ptr
<
ExecutionState
>
;
using
execution_state_ptr
=
std
::
shared_ptr
<
ExecutionState
>
;
//================================================================================================
//================================================================================================
// OutputDecl
// OutputDecl
// One value computed by an exop
// One value computed by an exop
//
//
// Arguments:
// Arguments:
// exop: The exop.
// exop: The exop.
// pos: The position of the value, defaults to 0.
// pos: The position of the value, defaults to 0.
// tensor_description: Tensor description of the value.
// tensor_description: Tensor description of the value.
// write_view: The tensor view where the value is written.
// write_view: The tensor view where the value is written.
//
//
// Attributes:
// Attributes:
// exop: The exop.
// exop: The exop.
// pos: The position of the value.
// pos: The position of the value.
// tensor_description: Tensor description of the value.
// tensor_description: Tensor description of the value.
// write_view: The tensor view where the value is written.
// write_view: The tensor view where the value is written.
// value_users: Arguments using this value.
// value_users: Arguments using this value.
//================================================================================================
//================================================================================================
class
OutputDecl
class
OutputDecl
{
{
public
:
public
:
OutputDecl
(
const
ExOp
&
_exop
,
size_t
_pos
,
tensor_decl_ptr
,
tensor_description_ptr
);
OutputDecl
(
const
ExOp
&
_exop
,
size_t
_pos
,
tensor_decl_ptr
,
tensor_description_ptr
);
tensor_decl_ptr
tensor_decl
();
tensor_decl_ptr
tensor_decl
();
void
tensor_decl
(
tensor_decl_ptr
tensor_decl
);
void
tensor_decl
(
tensor_decl_ptr
tensor_decl
);
...
@@ -95,29 +94,29 @@ public:
...
@@ -95,29 +94,29 @@ public:
tensor_decl_ptr
__tensor
;
tensor_decl_ptr
__tensor
;
tensor_view_decl_ptr
__write_view
;
tensor_view_decl_ptr
__write_view
;
std
::
set
<
InputDecl
*>
value_users
;
std
::
set
<
InputDecl
*>
value_users
;
};
};
//================================================================================================
//================================================================================================
// InputDecl
// InputDecl
// An argument for an exop.
// An argument for an exop.
//
//
// Arguments:
// Arguments:
// exop: The exop.
// exop: The exop.
// pos: The position of the value, defaults to 0.
// pos: The position of the value, defaults to 0.
// tensor_description: Tensor description of the value.
// tensor_description: Tensor description of the value.
// read_view: The tensor view where the value is read from.
// read_view: The tensor view where the value is read from.
//
//
// Attributes:
// Attributes:
// exop: The exop.
// exop: The exop.
// pos: The position of the value.
// pos: The position of the value.
// tensor_description: Tensor description of the value.
// tensor_description: Tensor description of the value.
// read_view: The tensor view where the value is read from.
// read_view: The tensor view where the value is read from.
// value: Arguments supplying this value.
// value: Arguments supplying this value.
//================================================================================================
//================================================================================================
class
InputDecl
class
InputDecl
{
{
public
:
public
:
InputDecl
(
const
ExOp
&
_exop
,
InputDecl
(
const
ExOp
&
_exop
,
size_t
_pos
,
size_t
_pos
,
tensor_description_ptr
_tensor_description
,
tensor_description_ptr
_tensor_description
,
...
@@ -134,37 +133,37 @@ public:
...
@@ -134,37 +133,37 @@ public:
tensor_description_ptr
tensor_description
;
tensor_description_ptr
tensor_description
;
tensor_view_decl_ptr
read_view
;
tensor_view_decl_ptr
read_view
;
OutputDecl
*
m_value
;
OutputDecl
*
m_value
;
};
};
//================================================================================================
//================================================================================================
// ExecutionGraphElt
// ExecutionGraphElt
// An element of an exection graph.
// An element of an exection graph.
//
//
// Arguments:
// Arguments:
// execution_graph: The execution graph that indexes this exop.
// execution_graph: The execution graph that indexes this exop.
//
//
// Attributes:
// Attributes:
// execution_graph: The execution graph that indexes this exop.
// execution_graph: The execution graph that indexes this exop.
//================================================================================================
//================================================================================================
class
ExecutionGraphElt
class
ExecutionGraphElt
{
{
public
:
public
:
ExecutionGraphElt
(
ExecutionGraph
&
eg
)
ExecutionGraphElt
(
ExecutionGraph
&
eg
)
:
execution_graph
{
eg
}
:
execution_graph
{
eg
}
{
{
}
}
ExecutionGraph
&
execution_graph
;
ExecutionGraph
&
execution_graph
;
};
};
//================================================================================================
//================================================================================================
// ExOp
// ExOp
//================================================================================================
//================================================================================================
class
ExOp
:
public
ExecutionGraphElt
class
ExOp
:
public
ExecutionGraphElt
{
{
public
:
public
:
// An exop that indicates an op to be executed.
// An exop that indicates an op to be executed.
// The op might be different from what was originally found in the computation graph.
// The op might be different from what was originally found in the computation graph.
...
@@ -220,17 +219,18 @@ public:
...
@@ -220,17 +219,18 @@ public:
std
::
vector
<
tensor_decl_ptr
>
liveness_free_list
;
std
::
vector
<
tensor_decl_ptr
>
liveness_free_list
;
std
::
vector
<
tensor_decl_ptr
>
liveness_new_list
;
std
::
vector
<
tensor_decl_ptr
>
liveness_new_list
;
std
::
vector
<
InputDecl
>
args
;
std
::
vector
<
InputDecl
>
args
;
std
::
vector
<
InputDecl
*>
write_args
;
// TODO: Kludge until we have values with writers/readers
std
::
vector
<
InputDecl
*>
write_args
;
// TODO: Kludge until we have values with writers/readers
std
::
vector
<
OutputDecl
>
values
;
std
::
vector
<
OutputDecl
>
values
;
};
};
//================================================================================================
//================================================================================================
// TensorDecl
// TensorDecl
//================================================================================================
//================================================================================================
class
TensorDecl
:
public
ExecutionGraphElt
class
TensorDecl
:
public
ExecutionGraphElt
{
{
public
:
public
:
// Allocate for a tensor.
// Allocate for a tensor.
// Arguments:
// Arguments:
...
@@ -294,15 +294,15 @@ public:
...
@@ -294,15 +294,15 @@ public:
bool
is_compile_only
;
bool
is_compile_only
;
tensor_ptr
initial_value
;
tensor_ptr
initial_value
;
tensor_decl_ptr
source_tensor
;
tensor_decl_ptr
source_tensor
;
};
};
//================================================================================================
//================================================================================================
// ExOpBlock
// ExOpBlock
//================================================================================================
//================================================================================================
class
ExOpBlock
:
public
ExecutionGraphElt
class
ExOpBlock
:
public
ExecutionGraphElt
{
{
public
:
public
:
// Sequentially execute a list of exops.
// Sequentially execute a list of exops.
// Attributes:
// Attributes:
...
@@ -312,7 +312,8 @@ public:
...
@@ -312,7 +312,8 @@ public:
// root_set: Set of exops whose values are needed.
// root_set: Set of exops whose values are needed.
ExOpBlock
(
ComputationDecl
&
cgraph
);
ExOpBlock
(
ComputationDecl
&
cgraph
);
bool
is_exop_end_of_list
();
bool
is_exop_end_of_list
();
void
add_ops
(
std
::
initializer_list
<
computation_op_ptr
>
roots
,
exop_ptr
after_exop
=
nullptr
);
void
add_ops
(
std
::
initializer_list
<
computation_op_ptr
>
roots
,
exop_ptr
after_exop
=
nullptr
);
exop_ptr
add_op
(
op_ptr
op
,
exop_ptr
after_exop
);
exop_ptr
add_op
(
op_ptr
op
,
exop_ptr
after_exop
);
exop_ptr
add_exop
(
exop_ptr
exop
,
exop_ptr
after_exop
=
nullptr
);
exop_ptr
add_exop
(
exop_ptr
exop
,
exop_ptr
after_exop
=
nullptr
);
void
move_exop_to_after_exop
(
exop_ptr
exop
,
exop_ptr
after_exop
);
void
move_exop_to_after_exop
(
exop_ptr
exop
,
exop_ptr
after_exop
);
...
@@ -336,17 +337,16 @@ public:
...
@@ -336,17 +337,16 @@ public:
// replacement for next_exop, prev_exop
// replacement for next_exop, prev_exop
std
::
list
<
exop_ptr
>::
iterator
begin
()
{
return
op_list
.
begin
();
}
std
::
list
<
exop_ptr
>::
iterator
begin
()
{
return
op_list
.
begin
();
}
std
::
list
<
exop_ptr
>::
iterator
end
()
{
return
op_list
.
end
();
}
std
::
list
<
exop_ptr
>::
iterator
end
()
{
return
op_list
.
end
();
}
std
::
list
<
exop_ptr
>
op_list
;
std
::
list
<
exop_ptr
>
op_list
;
};
};
//================================================================================================
//================================================================================================
// TensorViewDecl
// TensorViewDecl
//================================================================================================
//================================================================================================
class
TensorViewDecl
:
public
ExecutionGraphElt
class
TensorViewDecl
:
public
ExecutionGraphElt
{
{
public
:
public
:
// Declare a view of a tensor.
// Declare a view of a tensor.
// Arguments:
// Arguments:
...
@@ -373,17 +373,17 @@ public:
...
@@ -373,17 +373,17 @@ public:
std
::
set
<
InputDecl
*>
readers
;
std
::
set
<
InputDecl
*>
readers
;
std
::
set
<
OutputDecl
*>
writers
;
std
::
set
<
OutputDecl
*>
writers
;
OutputDecl
*
value
;
OutputDecl
*
value
;
};
};
// static exop_ptr _default_default;
// static exop_ptr _default_default;
//================================================================================================
//================================================================================================
// ComputationDecl
// ComputationDecl
//================================================================================================
//================================================================================================
class
ComputationDecl
:
public
ExecutionGraphElt
class
ComputationDecl
:
public
ExecutionGraphElt
{
{
public
:
public
:
// One computation to be run.
// One computation to be run.
// Every computation has its own execution graph. Persistent tensors are shared
// Every computation has its own execution graph. Persistent tensors are shared
...
@@ -406,15 +406,15 @@ public:
...
@@ -406,15 +406,15 @@ public:
exop_block_ptr
exop_block
;
exop_block_ptr
exop_block
;
exop_ptr
returns
;
exop_ptr
returns
;
std
::
set
<
ExOp
*>
values
;
std
::
set
<
ExOp
*>
values
;
};
};
//================================================================================================
//================================================================================================
// ExecutionState
// ExecutionState
//================================================================================================
//================================================================================================
class
ExecutionState
class
ExecutionState
{
{
public
:
public
:
// Proxy for the state of a device.
// Proxy for the state of a device.
// Arguments:
// Arguments:
...
@@ -429,15 +429,15 @@ public:
...
@@ -429,15 +429,15 @@ public:
// persistent tensors
// persistent tensors
std
::
map
<
tensor_description_ptr
,
tensor_decl_ptr
>
__tensors_decls
;
std
::
map
<
tensor_description_ptr
,
tensor_decl_ptr
>
__tensors_decls
;
};
};
//================================================================================================
//================================================================================================
// ExecutionGraph
// ExecutionGraph
//================================================================================================
//================================================================================================
class
ExecutionGraph
class
ExecutionGraph
{
{
public
:
public
:
// Information for compiling a computation_op.
// Information for compiling a computation_op.
// Arguments:
// Arguments:
...
@@ -452,6 +452,6 @@ public:
...
@@ -452,6 +452,6 @@ public:
// temporary tensors
// temporary tensors
std
::
map
<
tensor_description_ptr
,
tensor_decl_ptr
>
tensor_decls
;
std
::
map
<
tensor_description_ptr
,
tensor_decl_ptr
>
tensor_decls
;
computation_decl_ptr
computation_decl
;
computation_decl_ptr
computation_decl
;
};
};
}
// end namespace ngraph
}
// end namespace ngraph
src/transformers/mock.hpp
View file @
8c16125d
...
@@ -14,49 +14,47 @@
...
@@ -14,49 +14,47 @@
#pragma once
#pragma once
#include <string>
#include <memory>
#include <map>
#include <map>
#include <vector>
#include <memory>
#include <type_traits>
#include <sstream>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include "element_type.hpp"
#include "element_type.hpp"
namespace
ngraph
namespace
ngraph
{
{
class
ExecutionState
;
class
ExecutionState
;
class
Op
;
// class TensorDescription;
class
Op
;
class
ComputationOp
;
// class TensorDescription;
class
ComputationOp
;
using
computation_op_ptr
=
std
::
shared_ptr
<
ComputationOp
>
;
using
computation_op_ptr
=
std
::
shared_ptr
<
ComputationOp
>
;
using
op_ptr
=
std
::
shared_ptr
<
Op
>
;
using
op_ptr
=
std
::
shared_ptr
<
Op
>
;
using
scalar_t
=
float
;
using
scalar_t
=
float
;
//================================================================================================
//================================================================================================
// TensorInterface
// TensorInterface
//================================================================================================
//================================================================================================
class
TensorInterface
class
TensorInterface
{
{
public
:
public
:
virtual
~
TensorInterface
()
{}
virtual
~
TensorInterface
()
{}
virtual
const
ElementType
&
element_type
()
const
=
0
;
virtual
const
ElementType
&
element_type
()
const
=
0
;
virtual
std
::
string
value_string
()
const
=
0
;
virtual
std
::
string
value_string
()
const
=
0
;
};
};
//================================================================================================
//================================================================================================
// Tensor
// Tensor
//================================================================================================
//================================================================================================
template
<
typename
T
>
template
<
typename
T
>
class
Tensor
:
public
TensorInterface
class
Tensor
:
public
TensorInterface
{
{
public
:
public
:
Tensor
(
const
T
&
val
)
Tensor
(
const
T
&
val
)
:
m_value
{
val
}
:
m_value
{
val
}
,
m_element_type
{
element_type_float
}
,
m_element_type
{
element_type_float
}
...
@@ -64,9 +62,7 @@ public:
...
@@ -64,9 +62,7 @@ public:
}
}
virtual
~
Tensor
()
{}
virtual
~
Tensor
()
{}
const
ElementType
&
element_type
()
const
override
{
return
m_element_type
;
}
const
ElementType
&
element_type
()
const
override
{
return
m_element_type
;
}
std
::
string
value_string
()
const
override
std
::
string
value_string
()
const
override
{
{
std
::
string
rc
=
"WTF"
;
std
::
string
rc
=
"WTF"
;
...
@@ -79,104 +75,104 @@ public:
...
@@ -79,104 +75,104 @@ public:
return
rc
;
return
rc
;
}
}
private
:
private
:
T
m_value
;
T
m_value
;
ElementType
m_element_type
;
ElementType
m_element_type
;
};
};
//================================================================================================
//================================================================================================
// Transformer
// Transformer
//================================================================================================
//================================================================================================
class
Transformer
class
Transformer
{
{
public
:
public
:
virtual
~
Transformer
()
{}
virtual
~
Transformer
()
{}
virtual
ExecutionState
&
execution_state
()
=
0
;
virtual
ExecutionState
&
execution_state
()
=
0
;
};
};
//================================================================================================
//================================================================================================
// TensorDescription
// TensorDescription
//================================================================================================
//================================================================================================
// class TensorDescription
// class TensorDescription
// {
// {
// public:
// public:
// virtual ~TensorDescription();
// virtual ~TensorDescription();
// virtual axes_key_t axes_key() const = 0;
// virtual axes_key_t axes_key() const = 0;
// virtual std::string name() const = 0;
// virtual std::string name() const = 0;
// virtual std::vector<size_t> shape() const = 0;
// virtual std::vector<size_t> shape() const = 0;
// virtual std::shared_ptr<TensorDescription> base() = 0;
// virtual std::shared_ptr<TensorDescription> base() = 0;
// virtual ElementType element_type() const = 0;
// virtual ElementType element_type() const = 0;
// virtual size_t tensor_size() = 0;
// virtual size_t tensor_size() = 0;
// virtual bool is_persistent() = 0;
// virtual bool is_persistent() = 0;
// virtual bool is_input() = 0;
// virtual bool is_input() = 0;
// };
// };
//================================================================================================
//================================================================================================
// Op
// Op
//================================================================================================
//================================================================================================
// class Op
// class Op
// {
// {
// // Any operation that can be in an AST.
// // Any operation that can be in an AST.
// // Arguments:
// // Arguments:
// // args: Values used by this node.
// // args: Values used by this node.
// // const: The value of a constant Op, or None,
// // const: The value of a constant Op, or None,
// // constant (bool): The Op is constant. Default False.
// // constant (bool): The Op is constant. Default False.
// // forward: If not None, the node to use instead of this node.
// // forward: If not None, the node to use instead of this node.
// // metadata: String key value dictionary for frontend metadata.
// // metadata: String key value dictionary for frontend metadata.
// // kwargs: Args defined in related classes.
// // kwargs: Args defined in related classes.
// // Attributes:
// // Attributes:
// // const: The value of a constant.
// // const: The value of a constant.
// // constant (bool): The value is constant.
// // constant (bool): The value is constant.
// // control_deps (OrderedSet): Ops in addtion to args that must run before this op.
// // control_deps (OrderedSet): Ops in addtion to args that must run before this op.
// // persistent (bool): The value will be retained from computation to computation and
// // persistent (bool): The value will be retained from computation to computation and
// // not shared. Always True if reference is set.
// // not shared. Always True if reference is set.
// // metadata: Dictionary with of string keys and values used for attaching
// // metadata: Dictionary with of string keys and values used for attaching
// // arbitrary metadata to nodes.
// // arbitrary metadata to nodes.
// // trainable: The value is trainable.
// // trainable: The value is trainable.
// public:
// public:
// virtual ~Op() {}
// virtual ~Op() {}
// virtual std::string name() const = 0;
// virtual std::string name() const = 0;
// virtual tensor_description_ptr tensor_description() = 0;
// virtual tensor_description_ptr tensor_description() = 0;
// virtual op_ptr tensor() = 0;
// virtual op_ptr tensor() = 0;
// virtual bool is_tensor_op() = 0;
// virtual bool is_tensor_op() = 0;
// virtual bool is_state_op() const = 0;
// virtual bool is_state_op() const = 0;
// virtual bool is_sequencing_op() const = 0;
// virtual bool is_sequencing_op() const = 0;
// virtual op_ptr effective_tensor_op() = 0;
// virtual op_ptr effective_tensor_op() = 0;
// virtual const std::vector<op_ptr>& all_deps() const = 0;
// virtual const std::vector<op_ptr>& all_deps() const = 0;
// // ops
// // ops
// // TODO support multiple types
// // TODO support multiple types
// static op_ptr constant(float value)
// static op_ptr constant(float value)
// {
// {
// op_ptr = make_shared<LiteralScalarOp>(value);
// op_ptr = make_shared<LiteralScalarOp>(value);
// }
// }
// };
// };
//================================================================================================
//================================================================================================
// TensorOp
// TensorOp
//================================================================================================
//================================================================================================
// class TensorOp : public Op
// class TensorOp : public Op
// {
// {
// public:
// public:
// std::string name() const override { return "TensorOp"; }
// std::string name() const override { return "TensorOp"; }
// tensor_description_ptr tensor_description() override { return nullptr; }
// tensor_description_ptr tensor_description() override { return nullptr; }
// op_ptr tensor() override { return nullptr; }
// op_ptr tensor() override { return nullptr; }
// bool is_tensor_op() override { return true; }
// bool is_tensor_op() override { return true; }
// bool is_state_op() const override { return false; }
// bool is_state_op() const override { return false; }
// op_ptr effective_tensor_op() override { return nullptr; }
// op_ptr effective_tensor_op() override { return nullptr; }
// const std::vector<op_ptr>& all_deps() const override { return m_all_deps; }
// const std::vector<op_ptr>& all_deps() const override { return m_all_deps; }
// private:
// private:
// std::vector<op_ptr> m_all_deps;
// std::vector<op_ptr> m_all_deps;
// };
// };
}
// end of namespace ngraph
}
// end of namespace ngraph
src/transformers/mock_transformer.hpp
View file @
8c16125d
...
@@ -14,24 +14,21 @@
...
@@ -14,24 +14,21 @@
#pragma once
#pragma once
#include "mock.hpp"
#include "exop.hpp"
#include "exop.hpp"
#include "mock.hpp"
namespace
ngraph
namespace
ngraph
{
{
//================================================================================================
//================================================================================================
// CpuTransformer
// CpuTransformer
//================================================================================================
//================================================================================================
class
CpuTransformer
:
public
Transformer
class
CpuTransformer
:
public
Transformer
{
{
public
:
public
:
virtual
~
CpuTransformer
()
{}
virtual
~
CpuTransformer
()
{}
ExecutionState
&
execution_state
()
override
{
return
m_execution_state
;
}
ExecutionState
&
execution_state
()
override
{
return
m_execution_state
;
}
private
:
private
:
ExecutionState
m_execution_state
;
ExecutionState
m_execution_state
;
};
};
}
// end namespace ngraph
}
// end namespace ngraph
src/transformers/ndarray.hpp
View file @
8c16125d
...
@@ -14,8 +14,8 @@
...
@@ -14,8 +14,8 @@
#pragma once
#pragma once
#include <vector>
#include <memory>
#include <memory>
#include <vector>
#include "element_type.hpp"
#include "element_type.hpp"
#include "strides.hpp"
#include "strides.hpp"
...
...
src/transformers/op_graph.cpp
View file @
8c16125d
...
@@ -14,8 +14,8 @@
...
@@ -14,8 +14,8 @@
#include <sstream>
#include <sstream>
#include "op_graph.hpp"
#include "axes.hpp"
#include "axes.hpp"
#include "op_graph.hpp"
#include "util.hpp"
#include "util.hpp"
using
namespace
ngraph
;
using
namespace
ngraph
;
...
@@ -2794,7 +2794,9 @@ ElementWiseOp::ElementWiseOp()
...
@@ -2794,7 +2794,9 @@ ElementWiseOp::ElementWiseOp()
{
{
}
}
void
ElementWiseOp
::
ElementWiseOp_init
(
std
::
vector
<
op_ptr
>
,
Axes
)
{}
void
ElementWiseOp
::
ElementWiseOp_init
(
std
::
vector
<
op_ptr
>
,
Axes
)
{
}
//================================================================================================
//================================================================================================
// UnaryElementWiseOp
// UnaryElementWiseOp
...
...
src/transformers/op_graph.hpp
View file @
8c16125d
This source diff could not be displayed because it is too large. You can
view the blob
instead.
src/tree.hpp
View file @
8c16125d
#pragma once
#pragma once
#include <algorithm>
#include <functional>
#include <functional>
#include <vector>
#include <initializer_list>
#include <initializer_list>
#include <iostream>
#include <iostream>
#include <
algorithm
>
#include <
vector
>
#include "util.hpp"
#include "util.hpp"
...
@@ -51,7 +51,6 @@ public:
...
@@ -51,7 +51,6 @@ public:
bool
is_list
()
const
{
return
m_is_list
;
}
bool
is_list
()
const
{
return
m_is_list
;
}
T
get_value
()
const
{
return
m_value
;
}
T
get_value
()
const
{
return
m_value
;
}
const
std
::
vector
<
tree
>&
get_list
()
const
{
return
m_list
;
}
const
std
::
vector
<
tree
>&
get_list
()
const
{
return
m_list
;
}
static
void
traverse_tree
(
tree
&
s
,
std
::
function
<
void
(
T
*
)
>
func
)
static
void
traverse_tree
(
tree
&
s
,
std
::
function
<
void
(
T
*
)
>
func
)
{
{
if
(
s
.
is_list
())
if
(
s
.
is_list
())
...
...
src/util.cpp
View file @
8c16125d
...
@@ -12,8 +12,8 @@
...
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <map>
#include <iomanip>
#include <iomanip>
#include <map>
#include "util.hpp"
#include "util.hpp"
...
...
src/util.hpp
View file @
8c16125d
...
@@ -14,23 +14,22 @@
...
@@ -14,23 +14,22 @@
#pragma once
#pragma once
#include <string>
#include <sstream>
#include <vector>
#include <chrono>
#include <algorithm>
#include <algorithm>
#include <
map
>
#include <
chrono
>
#include <iostream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
namespace
ngraph
namespace
ngraph
{
{
class
stopwatch
;
extern
std
::
map
<
std
::
string
,
stopwatch
*>
stopwatch_statistics
;
class
stopwatch
;
template
<
typename
T
>
extern
std
::
map
<
std
::
string
,
stopwatch
*>
stopwatch_statistics
;
std
::
string
join
(
const
T
&
v
,
const
std
::
string
&
sep
)
{
template
<
typename
T
>
std
::
string
join
(
const
T
&
v
,
const
std
::
string
&
sep
)
{
std
::
ostringstream
ss
;
std
::
ostringstream
ss
;
for
(
const
auto
&
x
:
v
)
for
(
const
auto
&
x
:
v
)
{
{
...
@@ -41,11 +40,11 @@ std::string join(const T& v, const std::string& sep)
...
@@ -41,11 +40,11 @@ std::string join(const T& v, const std::string& sep)
ss
<<
x
;
ss
<<
x
;
}
}
return
ss
.
str
();
return
ss
.
str
();
}
}
template
<
typename
U
,
typename
T
>
template
<
typename
U
,
typename
T
>
bool
contains
(
const
U
&
container
,
const
T
&
obj
)
bool
contains
(
const
U
&
container
,
const
T
&
obj
)
{
{
bool
rc
=
false
;
bool
rc
=
false
;
for
(
auto
o
:
container
)
for
(
auto
o
:
container
)
{
{
...
@@ -56,11 +55,11 @@ bool contains(const U& container, const T& obj)
...
@@ -56,11 +55,11 @@ bool contains(const U& container, const T& obj)
}
}
}
}
return
rc
;
return
rc
;
}
}
template
<
typename
U
,
typename
T
>
template
<
typename
U
,
typename
T
>
bool
contains_key
(
const
U
&
container
,
const
T
&
obj
)
bool
contains_key
(
const
U
&
container
,
const
T
&
obj
)
{
{
bool
rc
=
false
;
bool
rc
=
false
;
for
(
auto
o
:
container
)
for
(
auto
o
:
container
)
{
{
...
@@ -71,28 +70,28 @@ bool contains_key(const U& container, const T& obj)
...
@@ -71,28 +70,28 @@ bool contains_key(const U& container, const T& obj)
}
}
}
}
return
rc
;
return
rc
;
}
}
template
<
typename
U
,
typename
T
>
template
<
typename
U
,
typename
T
>
void
remove_from
(
U
&
container
,
const
T
&
obj
)
void
remove_from
(
U
&
container
,
const
T
&
obj
)
{
{
auto
it
=
container
.
find
(
obj
);
auto
it
=
container
.
find
(
obj
);
if
(
it
!=
container
.
end
())
if
(
it
!=
container
.
end
())
{
{
container
.
erase
(
it
);
container
.
erase
(
it
);
}
}
}
}
size_t
hash_combine
(
const
std
::
vector
<
size_t
>&
list
);
size_t
hash_combine
(
const
std
::
vector
<
size_t
>&
list
);
void
dump
(
std
::
ostream
&
out
,
const
void
*
,
size_t
);
void
dump
(
std
::
ostream
&
out
,
const
void
*
,
size_t
);
std
::
string
to_lower
(
const
std
::
string
&
s
);
std
::
string
to_lower
(
const
std
::
string
&
s
);
std
::
string
trim
(
const
std
::
string
&
s
);
std
::
string
trim
(
const
std
::
string
&
s
);
std
::
vector
<
std
::
string
>
split
(
const
std
::
string
&
s
,
char
delimiter
,
bool
trim
=
false
);
std
::
vector
<
std
::
string
>
split
(
const
std
::
string
&
s
,
char
delimiter
,
bool
trim
=
false
);
class
stopwatch
class
stopwatch
{
{
public
:
public
:
stopwatch
()
{}
stopwatch
()
{}
stopwatch
(
const
std
::
string
&
name
)
stopwatch
(
const
std
::
string
&
name
)
:
m_name
{
name
}
:
m_name
{
name
}
...
@@ -149,21 +148,21 @@ public:
...
@@ -149,21 +148,21 @@ public:
size_t
get_total_milliseconds
()
const
{
return
get_total_nanoseconds
()
/
1e6
;
}
size_t
get_total_milliseconds
()
const
{
return
get_total_nanoseconds
()
/
1e6
;
}
size_t
get_total_microseconds
()
const
{
return
get_total_nanoseconds
()
/
1e3
;
}
size_t
get_total_microseconds
()
const
{
return
get_total_nanoseconds
()
/
1e3
;
}
size_t
get_total_nanoseconds
()
const
{
return
m_total_time
.
count
();
}
size_t
get_total_nanoseconds
()
const
{
return
m_total_time
.
count
();
}
private
:
private
:
std
::
chrono
::
high_resolution_clock
m_clock
;
std
::
chrono
::
high_resolution_clock
m_clock
;
std
::
chrono
::
time_point
<
std
::
chrono
::
high_resolution_clock
>
m_start_time
;
std
::
chrono
::
time_point
<
std
::
chrono
::
high_resolution_clock
>
m_start_time
;
bool
m_active
=
false
;
bool
m_active
=
false
;
std
::
chrono
::
nanoseconds
m_total_time
=
std
::
chrono
::
high_resolution_clock
::
duration
::
zero
();
std
::
chrono
::
nanoseconds
m_total_time
=
std
::
chrono
::
high_resolution_clock
::
duration
::
zero
();
std
::
chrono
::
nanoseconds
m_last_time
;
std
::
chrono
::
nanoseconds
m_last_time
;
size_t
m_total_count
=
0
;
size_t
m_total_count
=
0
;
std
::
string
m_name
;
std
::
string
m_name
;
};
};
template
<
class
InputIt
,
class
BinaryOp
>
template
<
class
InputIt
,
class
BinaryOp
>
typename
std
::
iterator_traits
<
InputIt
>::
value_type
typename
std
::
iterator_traits
<
InputIt
>::
value_type
reduce
(
InputIt
first
,
InputIt
last
,
BinaryOp
op
)
reduce
(
InputIt
first
,
InputIt
last
,
BinaryOp
op
)
{
{
typename
std
::
iterator_traits
<
InputIt
>::
value_type
result
;
typename
std
::
iterator_traits
<
InputIt
>::
value_type
result
;
if
(
first
==
last
)
if
(
first
==
last
)
...
@@ -180,18 +179,18 @@ typename std::iterator_traits<InputIt>::value_type
...
@@ -180,18 +179,18 @@ typename std::iterator_traits<InputIt>::value_type
}
}
}
}
return
result
;
return
result
;
}
}
template
<
typename
T
>
template
<
typename
T
>
T
plus
(
const
T
&
a
,
const
T
&
b
)
T
plus
(
const
T
&
a
,
const
T
&
b
)
{
{
return
a
+
b
;
return
a
+
b
;
}
}
template
<
typename
T
>
template
<
typename
T
>
T
mul
(
const
T
&
a
,
const
T
&
b
)
T
mul
(
const
T
&
a
,
const
T
&
b
)
{
{
return
a
*
b
;
return
a
*
b
;
}
}
}
// end namespace ngraph
}
// end namespace ngraph
src/uuid.hpp
View file @
8c16125d
...
@@ -15,10 +15,10 @@
...
@@ -15,10 +15,10 @@
#pragma once
#pragma once
#include <array>
#include <array>
#include <
random
>
#include <
cstring
>
#include <iomanip>
#include <iomanip>
#include <iostream>
#include <iostream>
#include <
cstring
>
#include <
random
>
static
std
::
mt19937_64
random_generator
;
static
std
::
mt19937_64
random_generator
;
...
@@ -74,7 +74,6 @@ public:
...
@@ -74,7 +74,6 @@ public:
}
}
bool
operator
!=
(
const
uuid_type
&
other
)
const
{
return
!
(
*
this
==
other
);
}
bool
operator
!=
(
const
uuid_type
&
other
)
const
{
return
!
(
*
this
==
other
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
uuid_type
&
id
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
uuid_type
&
id
)
{
{
out
<<
id
.
to_string
();
out
<<
id
.
to_string
();
...
...
test/axes.cpp
View file @
8c16125d
...
@@ -12,10 +12,10 @@
...
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
@@ -310,7 +310,7 @@ TEST(axes, index)
...
@@ -310,7 +310,7 @@ TEST(axes, index)
EXPECT_EQ
(
7
,
b
[
1
].
length
());
EXPECT_EQ
(
7
,
b
[
1
].
length
());
}
}
TEST
(
axes
,
as_nested_list
)
TEST
(
axes
,
DISABLED_
as_nested_list
)
{
{
Axis
C
=
make_axis
(
5
);
Axis
C
=
make_axis
(
5
);
Axis
H
=
make_axis
(
3
);
Axis
H
=
make_axis
(
3
);
...
@@ -325,7 +325,7 @@ TEST(axes, as_nested_list)
...
@@ -325,7 +325,7 @@ TEST(axes, as_nested_list)
FAIL
();
FAIL
();
}
}
TEST
(
axes
,
flatten
)
TEST
(
axes
,
DISABLED_
flatten
)
{
{
Axis
C
=
make_axis
(
5
);
Axis
C
=
make_axis
(
5
);
Axis
H
=
make_axis
(
3
);
Axis
H
=
make_axis
(
3
);
...
@@ -336,7 +336,7 @@ TEST(axes, flatten)
...
@@ -336,7 +336,7 @@ TEST(axes, flatten)
EXPECT_TRUE
(
c
.
is_flattened
());
EXPECT_TRUE
(
c
.
is_flattened
());
}
}
TEST
(
axes
,
as_flattened_list
)
TEST
(
axes
,
DISABLED_
as_flattened_list
)
{
{
FAIL
();
FAIL
();
}
}
...
@@ -364,7 +364,7 @@ TEST(axes, hash_axes)
...
@@ -364,7 +364,7 @@ TEST(axes, hash_axes)
m2
[
axes
]
=
1
;
m2
[
axes
]
=
1
;
}
}
TEST
(
axes
,
reaxe_0d_to_1d
)
TEST
(
axes
,
DISABLED_
reaxe_0d_to_1d
)
{
{
TensorDescription
td
{};
TensorDescription
td
{};
ngraph
::
ndarray
x
=
random
(
td
);
ngraph
::
ndarray
x
=
random
(
td
);
...
@@ -382,7 +382,7 @@ TEST(axes, reaxe_0d_to_1d)
...
@@ -382,7 +382,7 @@ TEST(axes, reaxe_0d_to_1d)
FAIL
();
FAIL
();
}
}
TEST
(
axes
,
reaxe_0d_to_2d
)
TEST
(
axes
,
DISABLED_
reaxe_0d_to_2d
)
{
{
// td = TensorDescription(axes=())
// td = TensorDescription(axes=())
// x = random(td)
// x = random(td)
...
@@ -407,7 +407,7 @@ TEST(axes, reaxe_0d_to_2d)
...
@@ -407,7 +407,7 @@ TEST(axes, reaxe_0d_to_2d)
// I started refactoring into smaller pieces as seen in tests above, but
// I started refactoring into smaller pieces as seen in tests above, but
// stopped ...
// stopped ...
//-----------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------
TEST
(
axes
,
simple_tensors
)
TEST
(
axes
,
DISABLED_
simple_tensors
)
{
{
// # A simple vector
// # A simple vector
// td1 = TensorDescription(axes=[ax_A])
// td1 = TensorDescription(axes=[ax_A])
...
@@ -582,7 +582,7 @@ TEST(axes, axes_map)
...
@@ -582,7 +582,7 @@ TEST(axes, axes_map)
// assert axes_after == axes_map.map_axes(axes_before)
// assert axes_after == axes_map.map_axes(axes_before)
}
}
TEST
(
axes
,
axes_map_immutable
)
TEST
(
axes
,
DISABLED_
axes_map_immutable
)
{
{
FAIL
();
FAIL
();
// axes_map = AxesMap({})
// axes_map = AxesMap({})
...
@@ -591,7 +591,7 @@ TEST(axes, axes_map_immutable)
...
@@ -591,7 +591,7 @@ TEST(axes, axes_map_immutable)
// axes_map["x"] = "y"
// axes_map["x"] = "y"
}
}
TEST
(
axes
,
axes_map_init_from_axes
)
TEST
(
axes
,
DISABLED_
axes_map_init_from_axes
)
{
{
FAIL
();
FAIL
();
// axes_map = AxesMap({ng.make_axis(1, name="aaa"): ng.make_axis(1, name="zzz")})
// axes_map = AxesMap({ng.make_axis(1, name="aaa"): ng.make_axis(1, name="zzz")})
...
...
test/element_type.cpp
View file @
8c16125d
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
...
test/exop.cpp
View file @
8c16125d
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
...
test/main.cpp
View file @
8c16125d
...
@@ -12,8 +12,8 @@
...
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <iostream>
#include <chrono>
#include <chrono>
#include <iostream>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
...
test/names.cpp
View file @
8c16125d
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
@@ -22,4 +22,6 @@
...
@@ -22,4 +22,6 @@
using
namespace
ngraph
;
using
namespace
ngraph
;
TEST
(
names
,
name
)
{}
TEST
(
names
,
name
)
{
}
test/op_graph.cpp
View file @
8c16125d
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
...
test/strides.cpp
View file @
8c16125d
...
@@ -12,10 +12,10 @@
...
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
...
test/tensor.cpp
View file @
8c16125d
...
@@ -12,10 +12,10 @@
...
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <memory>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
...
test/util.cpp
View file @
8c16125d
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
@@ -134,7 +134,9 @@ TEST(util, contains)
...
@@ -134,7 +134,9 @@ TEST(util, contains)
EXPECT_FALSE
(
contains
(
v1
,
8
));
EXPECT_FALSE
(
contains
(
v1
,
8
));
}
}
TEST
(
util
,
remove_from
)
{}
TEST
(
util
,
remove_from
)
{
}
TEST
(
util
,
reduce
)
TEST
(
util
,
reduce
)
{
{
...
...
test/uuid.cpp
View file @
8c16125d
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment