Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
94f7f22c
Commit
94f7f22c
authored
Aug 08, 2017
by
Robert Kimball
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
put everything ngraph in the ngraph namespace
parent
bdf38828
Hide whitespace changes
Inline
Side-by-side
Showing
25 changed files
with
164 additions
and
238 deletions
+164
-238
element_type.cpp
src/element_type.cpp
+12
-12
element_type.hpp
src/element_type.hpp
+13
-8
names.cpp
src/names.cpp
+2
-0
names.hpp
src/names.hpp
+6
-0
axes.cpp
src/transformers/axes.cpp
+9
-7
axes.hpp
src/transformers/axes.hpp
+36
-32
exop.cpp
src/transformers/exop.cpp
+6
-4
exop.hpp
src/transformers/exop.hpp
+5
-0
mock.hpp
src/transformers/mock.hpp
+5
-0
mock_transformer.hpp
src/transformers/mock_transformer.hpp
+5
-0
op_graph.cpp
src/transformers/op_graph.cpp
+4
-2
op_graph.hpp
src/transformers/op_graph.hpp
+5
-0
util.cpp
src/util.cpp
+6
-6
util.hpp
src/util.hpp
+30
-28
uuid.hpp
src/uuid.hpp
+6
-1
CMakeLists.txt
test/CMakeLists.txt
+1
-1
axes.cpp
test/axes.cpp
+1
-0
element_type.cpp
test/element_type.cpp
+2
-0
exop.cpp
test/exop.cpp
+2
-0
names.cpp
test/names.cpp
+2
-0
op.cpp
test/op.cpp
+0
-136
op_graph.cpp
test/op_graph.cpp
+3
-1
strides.cpp
test/strides.cpp
+1
-0
util.cpp
test/util.cpp
+1
-0
uuid.cpp
test/uuid.cpp
+1
-0
No files found.
src/element_type.cpp
View file @
94f7f22c
...
...
@@ -17,17 +17,17 @@
#include "element_type.hpp"
const
ElementType
element_type_float
=
ElementType
(
32
,
true
,
true
,
"float"
);
const
ElementType
element_type_int8_t
=
ElementType
(
8
,
false
,
true
,
"int8_t"
);
const
ElementType
element_type_int32_t
=
ElementType
(
32
,
false
,
true
,
"int32_t"
);
const
ElementType
element_type_int64_t
=
ElementType
(
64
,
false
,
true
,
"int64_t"
);
const
ElementType
element_type_uint8_t
=
ElementType
(
8
,
false
,
false
,
"int8_t"
);
const
ElementType
element_type_uint32_t
=
ElementType
(
32
,
false
,
false
,
"int32_t"
);
const
ElementType
element_type_uint64_t
=
ElementType
(
64
,
false
,
false
,
"int64_t"
);
const
ngraph
::
ElementType
element_type_float
=
ngraph
::
ElementType
(
32
,
true
,
true
,
"float"
);
const
ngraph
::
ElementType
element_type_int8_t
=
ngraph
::
ElementType
(
8
,
false
,
true
,
"int8_t"
);
const
ngraph
::
ElementType
element_type_int32_t
=
ngraph
::
ElementType
(
32
,
false
,
true
,
"int32_t"
);
const
ngraph
::
ElementType
element_type_int64_t
=
ngraph
::
ElementType
(
64
,
false
,
true
,
"int64_t"
);
const
ngraph
::
ElementType
element_type_uint8_t
=
ngraph
::
ElementType
(
8
,
false
,
false
,
"int8_t"
);
const
ngraph
::
ElementType
element_type_uint32_t
=
ngraph
::
ElementType
(
32
,
false
,
false
,
"int32_t"
);
const
ngraph
::
ElementType
element_type_uint64_t
=
ngraph
::
ElementType
(
64
,
false
,
false
,
"int64_t"
);
std
::
map
<
std
::
string
,
ElementType
>
ElementType
::
m_element_list
;
std
::
map
<
std
::
string
,
ngraph
::
ElementType
>
ngraph
::
ElementType
::
m_element_list
;
ElementType
::
ElementType
(
size_t
bitwidth
,
bool
is_float
,
bool
is_signed
,
const
std
::
string
&
cname
)
ngraph
::
ElementType
::
ElementType
(
size_t
bitwidth
,
bool
is_float
,
bool
is_signed
,
const
std
::
string
&
cname
)
:
m_bitwidth
{
bitwidth
}
,
m_is_float
{
is_float
}
,
m_is_signed
{
is_signed
}
...
...
@@ -36,18 +36,18 @@ ElementType::ElementType(size_t bitwidth, bool is_float, bool is_signed, const s
assert
(
m_bitwidth
%
8
==
0
);
}
const
std
::
string
&
ElementType
::
c_type_string
()
const
const
std
::
string
&
ngraph
::
ElementType
::
c_type_string
()
const
{
return
m_cname
;
}
bool
ElementType
::
operator
==
(
const
ElementType
&
other
)
const
bool
ngraph
::
ElementType
::
operator
==
(
const
ElementType
&
other
)
const
{
return
m_bitwidth
==
other
.
m_bitwidth
&&
m_is_float
==
other
.
m_is_float
&&
m_is_signed
==
other
.
m_is_signed
;
}
size_t
ElementType
::
size
()
const
size_t
ngraph
::
ElementType
::
size
()
const
{
return
std
::
ceil
((
float
)
m_bitwidth
/
8.0
);
}
src/element_type.hpp
View file @
94f7f22c
...
...
@@ -21,7 +21,12 @@
#include <string>
#include <map>
class
ElementType
namespace
ngraph
{
class
ElementType
;
}
class
ngraph
::
ElementType
{
public
:
ElementType
(
size_t
bitwidth
,
bool
is_float
,
bool
is_signed
,
const
std
::
string
&
cname
);
...
...
@@ -44,10 +49,10 @@ private:
const
std
::
string
m_cname
;
};
extern
const
ElementType
element_type_float
;
extern
const
ElementType
element_type_int8_t
;
extern
const
ElementType
element_type_int32_t
;
extern
const
ElementType
element_type_int64_t
;
extern
const
ElementType
element_type_uint8_t
;
extern
const
ElementType
element_type_uint32_t
;
extern
const
ElementType
element_type_uint64_t
;
extern
const
ngraph
::
ElementType
element_type_float
;
extern
const
ngraph
::
ElementType
element_type_int8_t
;
extern
const
ngraph
::
ElementType
element_type_int32_t
;
extern
const
ngraph
::
ElementType
element_type_int64_t
;
extern
const
ngraph
::
ElementType
element_type_uint8_t
;
extern
const
ngraph
::
ElementType
element_type_uint32_t
;
extern
const
ngraph
::
ElementType
element_type_uint64_t
;
src/names.cpp
View file @
94f7f22c
...
...
@@ -16,6 +16,8 @@
#include "names.hpp"
using
namespace
ngraph
;
size_t
NameableValue
::
__counter
=
0
;
std
::
map
<
std
::
string
,
NameableValue
>
NameableValue
::
__all_names
;
...
...
src/names.hpp
View file @
94f7f22c
...
...
@@ -17,6 +17,9 @@
#include <string>
#include <map>
namespace
ngraph
{
//================================================================================================
// NameableValue
// An Axis labels a dimension of a tensor. The op-graph uses
...
...
@@ -101,3 +104,6 @@ public:
std
::
string
m_short_name
;
std
::
string
m_doc_string
;
};
}
// end namespace ngraph
src/transformers/axes.cpp
View file @
94f7f22c
...
...
@@ -20,6 +20,8 @@
#include "axes.hpp"
#include "util.hpp"
using
namespace
ngraph
;
slice
::
slice
(
int64_t
start
,
int64_t
stop
,
int64_t
step
)
:
m_start
{(
size_t
)
start
}
,
m_stop
{(
size_t
)
stop
}
...
...
@@ -90,12 +92,12 @@ size_t slice::sliced_length(size_t length) const
// raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype
Axis
make_axis
(
size_t
length
,
const
std
::
string
&
name
,
bool
batch
,
bool
recurrent
)
Axis
ngraph
::
make_axis
(
size_t
length
,
const
std
::
string
&
name
,
bool
batch
,
bool
recurrent
)
{
return
Axis
(
length
,
name
);
}
Axes
make_axes
(
const
std
::
vector
<
Axis
>&
axis_list
)
Axes
ngraph
::
make_axes
(
const
std
::
vector
<
Axis
>&
axis_list
)
{
return
Axes
(
axis_list
);
}
...
...
@@ -175,7 +177,7 @@ void Axis::length(size_t l)
__length
=
l
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
Axis
&
axis
)
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
Axis
&
axis
)
{
out
<<
axis
.
to_string
();
return
out
;
...
...
@@ -238,7 +240,7 @@ bool Axis::operator<(const Axis& other) const
// // ))
// }
Axis
slice_axis
(
const
Axis
&
axis
,
const
slice
&
s
)
Axis
ngraph
::
slice_axis
(
const
Axis
&
axis
,
const
slice
&
s
)
{
// _validate_slice(s)
...
...
@@ -263,7 +265,7 @@ Axis slice_axis(const Axis& axis, const slice& s)
// Returns:
// list of Axis: duplicate Axis found in arr
// """
std
::
vector
<
std
::
string
>
duplicates
(
const
std
::
vector
<
Axis
>&
ax
)
std
::
vector
<
std
::
string
>
ngraph
::
duplicates
(
const
std
::
vector
<
Axis
>&
ax
)
{
std
::
map
<
std
::
string
,
size_t
>
counts
;
std
::
vector
<
std
::
string
>
rc
;
...
...
@@ -835,7 +837,7 @@ bool Axes::operator<(const Axes& other) const
// """
// return int(np.prod(self.lengths))
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
Axes
&
axes
)
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
Axes
&
axes
)
{
out
<<
"Axes("
;
out
<<
join
(
axes
.
axes
,
", "
);
...
...
@@ -1060,7 +1062,7 @@ FlattenedAxis::FlattenedAxis(const std::vector<Axis>& list, const std::string& n
axes
=
list
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
FlattenedAxis
&
obj
)
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
FlattenedAxis
&
obj
)
{
out
<<
obj
.
to_string
();
return
out
;
...
...
src/transformers/axes.hpp
View file @
94f7f22c
...
...
@@ -28,6 +28,8 @@
#include "strides.hpp"
#include "uuid.hpp"
namespace
ngraph
{
class
Axes
;
class
Axis
;
class
FlattenedAxis
;
...
...
@@ -232,20 +234,6 @@ public:
static
size_t
__name_counter
;
};
namespace
std
{
template
<>
struct
std
::
hash
<
Axis
>
{
size_t
operator
()(
const
Axis
&
axis
)
const
{
std
::
hash
<
std
::
string
>
h1
;
std
::
hash
<
size_t
>
h2
;
return
hash_combine
({
h1
(
axis
.
name
),
h2
(
axis
.
length
())});
}
};
}
//-----------------------------------------------------------------------------------------------
// _sliced_length
//-----------------------------------------------------------------------------------------------
...
...
@@ -722,24 +710,6 @@ private:
void
check_duplicates
();
};
namespace
std
{
template
<>
struct
std
::
hash
<
Axes
>
{
size_t
operator
()(
const
Axes
&
axes
)
const
{
std
::
hash
<
Axis
>
h1
;
std
::
vector
<
size_t
>
hashes
;
for
(
auto
axis
:
axes
)
{
hashes
.
push_back
(
h1
(
axis
));
}
return
hash_combine
(
hashes
);
}
};
}
//================================================================================================
// DuplicateAxisNames
//================================================================================================
...
...
@@ -1518,3 +1488,37 @@ public:
ngraph
::
tensor_stride
full_strides
;
tensor_description_ptr
next_tensor_description
;
};
}
// end of namespace ngraph
namespace
std
{
template
<>
struct
std
::
hash
<
ngraph
::
Axis
>
{
size_t
operator
()(
const
ngraph
::
Axis
&
axis
)
const
{
std
::
hash
<
std
::
string
>
h1
;
std
::
hash
<
size_t
>
h2
;
return
ngraph
::
hash_combine
({
h1
(
axis
.
name
),
h2
(
axis
.
length
())});
}
};
}
namespace
std
{
template
<>
struct
std
::
hash
<
ngraph
::
Axes
>
{
size_t
operator
()(
const
ngraph
::
Axes
&
axes
)
const
{
std
::
hash
<
ngraph
::
Axis
>
h1
;
std
::
vector
<
size_t
>
hashes
;
for
(
auto
axis
:
axes
)
{
hashes
.
push_back
(
h1
(
axis
));
}
return
ngraph
::
hash_combine
(
hashes
);
}
};
}
src/transformers/exop.cpp
View file @
94f7f22c
...
...
@@ -21,6 +21,8 @@
#include "op_graph.hpp"
#include "util.hpp"
using
namespace
ngraph
;
//================================================================================================
// InputDecl
//================================================================================================
...
...
@@ -77,7 +79,7 @@ void InputDecl::value(OutputDecl* value)
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
InputDecl
&
obj
)
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
InputDecl
&
obj
)
{
out
<<
"Arg("
<<
obj
.
exop
.
name
()
<<
obj
.
pos
<<
")"
;
return
out
;
...
...
@@ -142,7 +144,7 @@ void OutputDecl::write_view(tensor_view_decl_ptr view)
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
OutputDecl
&
obj
)
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
OutputDecl
&
obj
)
{
out
<<
"Val("
<<
obj
.
exop
.
name
()
<<
":"
<<
obj
.
pos
<<
")"
;
return
out
;
...
...
@@ -191,7 +193,7 @@ ExOp::ExOp(ComputationDecl& cgraph, op_ptr _op, bool create_value)
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
ExOp
&
obj
)
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
ExOp
&
obj
)
{
out
<<
obj
.
op
->
name
();
std
::
vector
<
std
::
string
>
args
;
...
...
@@ -833,7 +835,7 @@ std::string TensorDecl::buffer_name()
// return op->name();
// }
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
TensorDecl
&
obj
)
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
TensorDecl
&
obj
)
{
out
<<
obj
.
tensor_description_base
->
name
();
return
out
;
...
...
src/transformers/exop.hpp
View file @
94f7f22c
...
...
@@ -27,6 +27,9 @@
#include "op_graph.hpp"
#include "axes.hpp"
namespace
ngraph
{
// forward declaration. This will hopefully go away
class
ExecutionGraph
;
class
TensorDescription
;
...
...
@@ -450,3 +453,5 @@ public:
std
::
map
<
tensor_description_ptr
,
tensor_decl_ptr
>
tensor_decls
;
computation_decl_ptr
computation_decl
;
};
}
// end namespace ngraph
src/transformers/mock.hpp
View file @
94f7f22c
...
...
@@ -23,6 +23,9 @@
#include "element_type.hpp"
namespace
ngraph
{
class
ExecutionState
;
class
Op
;
...
...
@@ -175,3 +178,5 @@ public:
// private:
// std::vector<op_ptr> m_all_deps;
// };
}
// end of namespace ngraph
src/transformers/mock_transformer.hpp
View file @
94f7f22c
...
...
@@ -17,6 +17,9 @@
#include "mock.hpp"
#include "exop.hpp"
namespace
ngraph
{
//================================================================================================
// CpuTransformer
//================================================================================================
...
...
@@ -30,3 +33,5 @@ public:
private
:
ExecutionState
m_execution_state
;
};
}
// end namespace ngraph
src/transformers/op_graph.cpp
View file @
94f7f22c
...
...
@@ -18,6 +18,8 @@
#include "axes.hpp"
#include "util.hpp"
using
namespace
ngraph
;
// def tensor_descriptions(args):
// """
// A list of tensor descriptions for Ops.
...
...
@@ -1913,7 +1915,7 @@ BroadcastOp::BroadcastOp(op_ptr x, Axes axes)
// dx_reordered = axes_with_order(dx, x.axes)
// x.generate_add_delta(adjoints, dx_reordered)
op_ptr
broadcast
(
op_ptr
x
,
const
Axes
&
axes
)
op_ptr
ngraph
::
broadcast
(
op_ptr
x
,
const
Axes
&
axes
)
{
// auto axes = make_axes(axis_list);
op_ptr
rc
;
...
...
@@ -1928,7 +1930,7 @@ op_ptr broadcast(op_ptr x, const Axes& axes)
return
rc
;
}
op_ptr
axes_with_order
(
op_ptr
x
,
const
std
::
vector
<
Axis
>&
axis_list
)
op_ptr
ngraph
::
axes_with_order
(
op_ptr
x
,
const
std
::
vector
<
Axis
>&
axis_list
)
{
auto
axes
=
make_axes
(
axis_list
);
op_ptr
rc
;
...
...
src/transformers/op_graph.hpp
View file @
94f7f22c
...
...
@@ -23,6 +23,9 @@
#include "axes.hpp"
#include "names.hpp"
namespace
ngraph
{
class
Op
;
class
AssignableTensorOp
;
class
ParallelOp
;
...
...
@@ -4427,3 +4430,5 @@ public:
// private:
// std::vector<op_ptr> m_all_deps;
};
}
// end namespace ngraph
src/util.cpp
View file @
94f7f22c
...
...
@@ -19,9 +19,9 @@
using
namespace
std
;
map
<
string
,
stopwatch
*>
stopwatch_statistics
;
map
<
string
,
ngraph
::
stopwatch
*>
ngraph
::
stopwatch_statistics
;
void
dump
(
ostream
&
out
,
const
void
*
_data
,
size_t
_size
)
void
ngraph
::
dump
(
ostream
&
out
,
const
void
*
_data
,
size_t
_size
)
{
auto
flags
=
out
.
flags
();
const
uint8_t
*
data
=
reinterpret_cast
<
const
uint8_t
*>
(
_data
);
...
...
@@ -66,14 +66,14 @@ void dump(ostream& out, const void* _data, size_t _size)
out
.
flags
(
flags
);
}
std
::
string
to_lower
(
const
std
::
string
&
s
)
std
::
string
ngraph
::
to_lower
(
const
std
::
string
&
s
)
{
std
::
string
rc
=
s
;
std
::
transform
(
rc
.
begin
(),
rc
.
end
(),
rc
.
begin
(),
::
tolower
);
return
rc
;
}
string
trim
(
const
string
&
s
)
string
ngraph
::
trim
(
const
string
&
s
)
{
string
rc
=
s
;
// trim trailing spaces
...
...
@@ -92,7 +92,7 @@ string trim(const string& s)
return
rc
;
}
vector
<
string
>
split
(
const
string
&
src
,
char
delimiter
,
bool
do_trim
)
vector
<
string
>
ngraph
::
split
(
const
string
&
src
,
char
delimiter
,
bool
do_trim
)
{
size_t
pos
;
string
token
;
...
...
@@ -120,7 +120,7 @@ vector<string> split(const string& src, char delimiter, bool do_trim)
return
rc
;
}
size_t
hash_combine
(
const
std
::
vector
<
size_t
>&
list
)
size_t
ngraph
::
hash_combine
(
const
std
::
vector
<
size_t
>&
list
)
{
size_t
seed
=
0
;
for
(
size_t
v
:
list
)
...
...
src/util.hpp
View file @
94f7f22c
...
...
@@ -22,6 +22,9 @@
#include <map>
#include <iostream>
namespace
ngraph
{
class
stopwatch
;
extern
std
::
map
<
std
::
string
,
stopwatch
*>
stopwatch_statistics
;
...
...
@@ -157,39 +160,38 @@ private:
std
::
string
m_name
;
};
namespace
ngraph
template
<
class
InputIt
,
class
BinaryOp
>
typename
std
::
iterator_traits
<
InputIt
>::
value_type
reduce
(
InputIt
first
,
InputIt
last
,
BinaryOp
op
)
{
template
<
class
InputIt
,
class
BinaryOp
>
typename
std
::
iterator_traits
<
InputIt
>::
value_type
reduce
(
InputIt
first
,
InputIt
last
,
BinaryOp
op
)
{
typename
std
::
iterator_traits
<
InputIt
>::
value_type
result
;
typename
std
::
iterator_traits
<
InputIt
>::
value_type
result
;
if
(
first
==
last
)
{
result
=
{};
}
else
if
(
first
==
last
)
{
result
=
{};
}
else
{
result
=
*
first
++
;
while
(
first
!=
last
)
{
result
=
*
first
++
;
while
(
first
!=
last
)
{
result
=
op
(
result
,
*
first
);
first
++
;
}
result
=
op
(
result
,
*
first
);
first
++
;
}
return
result
;
}
return
result
;
}
template
<
typename
T
>
T
plus
(
const
T
&
a
,
const
T
&
b
)
{
return
a
+
b
;
}
template
<
typename
T
>
T
plus
(
const
T
&
a
,
const
T
&
b
)
{
return
a
+
b
;
}
template
<
typename
T
>
T
mul
(
const
T
&
a
,
const
T
&
b
)
{
return
a
*
b
;
}
template
<
typename
T
>
T
mul
(
const
T
&
a
,
const
T
&
b
)
{
return
a
*
b
;
}
}
// end namespace ngraph
src/uuid.hpp
View file @
94f7f22c
...
...
@@ -22,7 +22,12 @@
static
std
::
mt19937_64
random_generator
;
class
uuid_type
namespace
ngraph
{
class
uuid_type
;
}
class
ngraph
::
uuid_type
{
public
:
uuid_type
()
...
...
test/CMakeLists.txt
View file @
94f7f22c
...
...
@@ -27,7 +27,7 @@ set (SRC
exop.cpp
axes.cpp
element_type.cpp
op.cpp
op
_graph
.cpp
uuid.cpp
names.cpp
strides.cpp
...
...
test/axes.cpp
View file @
94f7f22c
...
...
@@ -23,6 +23,7 @@
#include "transformers/ndarray.hpp"
using
namespace
std
;
using
namespace
ngraph
;
// axes for testing
static
auto
ax_A
=
make_axis
(
2
,
"A"
);
...
...
test/element_type.cpp
View file @
94f7f22c
...
...
@@ -19,3 +19,5 @@
#include "gtest/gtest.h"
#include "element_type.hpp"
using
namespace
ngraph
;
test/exop.cpp
View file @
94f7f22c
...
...
@@ -22,6 +22,8 @@
#include "transformers/mock.hpp"
#include "transformers/mock_transformer.hpp"
using
namespace
ngraph
;
TEST
(
exop
,
create
)
{
// CpuTransformer transformer;
...
...
test/names.cpp
View file @
94f7f22c
...
...
@@ -20,4 +20,6 @@
#include "names.hpp"
using
namespace
ngraph
;
TEST
(
names
,
name
)
{}
test/op.cpp
deleted
100644 → 0
View file @
bdf38828
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include "gtest/gtest.h"
#include "transformers/op_graph.hpp"
TEST
(
op
,
constant
)
{
float
expected_value
=
42
;
op_ptr
x
=
constant
(
expected_value
);
ASSERT_NE
(
nullptr
,
x
);
EXPECT_EQ
(
true
,
x
->
is_constant
());
EXPECT_EQ
(
false
,
x
->
is_input
());
EXPECT_EQ
(
true
,
x
->
is_persistent
());
EXPECT_EQ
(
false
,
x
->
is_trainable
());
EXPECT_EQ
(
false
,
x
->
is_placeholder
());
auto
ato
=
std
::
dynamic_pointer_cast
<
AssignableTensorOp
>
(
x
);
ASSERT_NE
(
nullptr
,
ato
);
// TODO: fix this
auto
ti
=
ato
->
m_value
;
ASSERT_NE
(
nullptr
,
ti
);
std
::
string
actual_value
=
ti
->
value_string
();
std
::
stringstream
ss
;
ss
<<
expected_value
;
std
::
string
expected_string
=
ss
.
str
();
EXPECT_STREQ
(
actual_value
.
c_str
(),
expected_string
.
c_str
());
}
// @pytest.fixture()
// def N():
// return ng.make_axis(length=1)
// def test_deriv_missing_connection(N):
// """
// Taking the derivative of an expression with respect to a variable not
// used to compute the expression should raise an exception.
// """
// x = ng.variable([N])
// y = ng.variable([N])
// z = ng.variable([N])
// with pytest.raises(ValueError):
// ng.deriv(x + y, z)
// def test_one():
// # Test that the cacheing on constant one used in DerivOp works.
// op = ng.variable([])
// one_0 = op.one
// one_1 = op.one
// assert one_0 is one_1
// def test_pad_invalid_paddings_length(N):
// """
// pad should raise an exception if the paddings length is not the same as the
// input dimensionality.
// """
// x = ng.variable([N])
// with pytest.raises(ValueError):
// ng.pad(x, [1, 0])
// def test_pad_0(N):
// """
// pad with length 0 should be a nop
// """
// x = ng.variable([N])
// assert ng.pad(x, [0]).axes == x.axes
// def test_pad_mixed():
// """
// mix 0 padding with non-0 padding
// """
// input_axes = ng.make_axes([
// ng.make_axis(1),
// ng.make_axis(1)
// ])
// x = ng.variable(input_axes)
// pad = ng.pad(x, [0, 1])
// assert pad.axes[0] == x.axes[0]
// assert pad.axes[1] != x.axes[1]
// def test_slice_nop():
// """
// slicing an axis shouldn't change the name
// """
// input_axes = ng.make_axes([
// ng.make_axis(1),
// ng.make_axis(1)
// ])
// x = ng.variable(input_axes)
// s = ng.tensor_slice(x, [
// slice(None, None, None),
// slice(None, None, 1),
// ])
// assert s.axes[0] == x.axes[0]
// assert s.axes[1] == x.axes[1]
// def test_tensor_slice():
// """
// slicing a tensor should work like numpy
// """
// input_axes = ng.make_axes([
// ng.make_axis(10),
// ng.make_axis(20),
// ng.make_axis(5)
// ])
// x = ng.placeholder(axes=input_axes)
// assert x[:5].axes.full_lengths == (5, 20, 5)
// assert x[:, 2:7].axes.full_lengths == (10, 5, 5)
// assert x[:5, :, :-1].axes.full_lengths == (5, 20, 4)
test/op_graph.cpp
View file @
94f7f22c
...
...
@@ -20,6 +20,8 @@
#include "transformers/op_graph.hpp"
using
namespace
ngraph
;
TEST
(
op_graph
,
constant
)
{
float
expected_value
=
42
;
...
...
@@ -62,7 +64,7 @@ Axis N()
TEST
(
op_graph
,
deriv_missing_connection
)
{
// x = ng.variable([N])
auto
x
=
variable
({
N
()});
//
auto x = variable({N()});
// y = ng.variable([N])
// z = ng.variable([N])
...
...
test/strides.cpp
View file @
94f7f22c
...
...
@@ -22,6 +22,7 @@
#include "strides.hpp"
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
strides
,
scalar_tree_ctor
)
{
...
...
test/util.cpp
View file @
94f7f22c
...
...
@@ -21,6 +21,7 @@
#include "util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
util
,
split
)
{
...
...
test/uuid.cpp
View file @
94f7f22c
...
...
@@ -21,6 +21,7 @@
#include "uuid.hpp"
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
uuid
,
zero
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment