Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
63fdc66e
Commit
63fdc66e
authored
Oct 16, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add all_close for comparing tensors.
parent
d79e0353
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
162 additions
and
1 deletion
+162
-1
dense_tensor_view_layout.cpp
src/ngraph/descriptor/layout/dense_tensor_view_layout.cpp
+18
-0
dense_tensor_view_layout.hpp
src/ngraph/descriptor/layout/dense_tensor_view_layout.hpp
+3
-1
tensor_view_layout.hpp
src/ngraph/descriptor/layout/tensor_view_layout.hpp
+3
-0
tensor_view.cpp
src/ngraph/runtime/tensor_view.cpp
+6
-0
tensor_view.hpp
src/ngraph/runtime/tensor_view.hpp
+3
-0
utils.cpp
src/ngraph/runtime/utils.cpp
+58
-0
utils.hpp
src/ngraph/runtime/utils.hpp
+44
-0
element_type.hpp
src/ngraph/types/element_type.hpp
+3
-0
util.cpp
test/util.cpp
+24
-0
No files found.
src/ngraph/descriptor/layout/dense_tensor_view_layout.cpp
View file @
63fdc66e
...
...
@@ -44,3 +44,21 @@ size_t DenseTensorViewLayout::get_index_offset(const std::vector<size_t>& indice
}
return
result
;
}
bool
DenseTensorViewLayout
::
operator
==
(
const
TensorViewLayout
&
other
)
const
{
const
DenseTensorViewLayout
*
p_other
=
dynamic_cast
<
const
DenseTensorViewLayout
*>
(
&
other
);
if
(
nullptr
==
p_other
)
return
false
;
if
(
get_element_type
()
!=
p_other
->
get_element_type
())
return
false
;
if
(
m_strides
!=
p_other
->
m_strides
)
return
false
;
if
(
m_offset
!=
p_other
->
m_offset
)
return
false
;
return
true
;
}
src/ngraph/descriptor/layout/dense_tensor_view_layout.hpp
View file @
63fdc66e
...
...
@@ -41,9 +41,11 @@ namespace ngraph
virtual
size_t
get_index_offset
(
const
std
::
vector
<
size_t
>&
indices
)
override
;
const
Strides
&
get_strides
()
const
{
return
m_strides
;
}
virtual
bool
operator
==
(
const
TensorViewLayout
&
other
)
const
override
;
protected
:
Strides
m_strides
;
size_t
m_offset
;
size_t
m_offset
{
0
}
;
size_t
m_size
;
};
}
...
...
src/ngraph/descriptor/layout/tensor_view_layout.hpp
View file @
63fdc66e
...
...
@@ -59,6 +59,9 @@ namespace ngraph
/// Where this view is located in the buffer.
const
BufferPos
&
get_buffer_pos
()
const
{
return
m_buffer_pos
;
}
BufferPos
&
get_buffer_pos
()
{
return
m_buffer_pos
;
}
/// @brief Return true if this and other have the same element interpretation
virtual
bool
operator
==
(
const
TensorViewLayout
&
other
)
const
=
0
;
bool
operator
!=
(
const
TensorViewLayout
&
other
)
const
{
return
!
(
*
this
==
other
);
}
protected
:
std
::
shared_ptr
<
const
TensorViewType
>
m_tensor_view_type
;
BufferPos
m_buffer_pos
;
...
...
src/ngraph/runtime/tensor_view.cpp
View file @
63fdc66e
...
...
@@ -38,3 +38,9 @@ const ngraph::Shape& TensorView::get_shape() const
{
return
m_descriptor
->
get_tensor_view_type
()
->
get_shape
();
}
std
::
shared_ptr
<
ngraph
::
descriptor
::
layout
::
TensorViewLayout
>
TensorView
::
get_tensor_view_layout
()
const
{
return
m_descriptor
->
get_tensor_view_layout
();
}
src/ngraph/runtime/tensor_view.hpp
View file @
63fdc66e
...
...
@@ -60,6 +60,9 @@ namespace ngraph
const
ngraph
::
Shape
&
get_shape
()
const
;
std
::
shared_ptr
<
ngraph
::
descriptor
::
layout
::
TensorViewLayout
>
get_tensor_view_layout
()
const
;
/// @brief Write bytes directly into the tensor
/// @param p Pointer to source of data
/// @param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
...
...
src/ngraph/runtime/utils.cpp
View file @
63fdc66e
...
...
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <cmath>
#include "ngraph/runtime/utils.hpp"
std
::
shared_ptr
<
ngraph
::
runtime
::
Tuple
>
ngraph
::
runtime
::
make_tuple
(
...
...
@@ -19,3 +22,58 @@ std::shared_ptr<ngraph::runtime::Tuple> ngraph::runtime::make_tuple(
{
return
std
::
make_shared
<
ngraph
::
runtime
::
Tuple
>
(
elements
);
}
template
<
typename
ET
>
bool
ngraph
::
runtime
::
all_close
(
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ET
>>&
a
,
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ET
>>&
b
,
typename
ET
::
type
rtol
,
typename
ET
::
type
atol
)
{
// Check that the layouts are compatible
if
(
*
a
->
get_tensor_view_layout
()
!=
*
b
->
get_tensor_view_layout
())
{
throw
ngraph_error
(
"Cannot compare tensors with different layouts"
);
}
if
(
a
->
get_shape
()
!=
b
->
get_shape
())
return
false
;
return
ngraph
::
runtime
::
all_close
(
a
->
get_vector
(),
b
->
get_vector
(),
rtol
,
atol
);
}
template
bool
ngraph
::
runtime
::
all_close
<
ngraph
::
element
::
Float32
>
(
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ngraph
::
element
::
Float32
>>&
a
,
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ngraph
::
element
::
Float32
>>&
b
,
ngraph
::
element
::
Float32
::
type
rtol
,
ngraph
::
element
::
Float32
::
type
atol
);
template
bool
ngraph
::
runtime
::
all_close
<
ngraph
::
element
::
Float64
>
(
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ngraph
::
element
::
Float64
>>&
a
,
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ngraph
::
element
::
Float64
>>&
b
,
ngraph
::
element
::
Float64
::
type
rtol
,
ngraph
::
element
::
Float64
::
type
atol
);
template
<
typename
T
>
bool
ngraph
::
runtime
::
all_close
(
const
std
::
vector
<
T
>&
a
,
const
std
::
vector
<
T
>&
b
,
T
rtol
,
T
atol
)
{
assert
(
a
.
size
()
==
b
.
size
());
for
(
size_t
i
=
0
;
i
<
a
.
size
();
++
i
)
{
if
(
std
::
abs
(
a
[
i
]
-
b
[
i
])
>
atol
+
rtol
*
std
::
abs
(
b
[
i
]))
{
return
false
;
}
}
return
true
;
}
template
bool
ngraph
::
runtime
::
all_close
<
float
>
(
const
std
::
vector
<
float
>&
a
,
const
std
::
vector
<
float
>&
b
,
float
rtol
,
float
atol
);
template
bool
ngraph
::
runtime
::
all_close
<
double
>
(
const
std
::
vector
<
double
>&
a
,
const
std
::
vector
<
double
>&
b
,
double
rtol
,
double
atol
);
src/ngraph/runtime/utils.hpp
View file @
63fdc66e
...
...
@@ -37,5 +37,49 @@ namespace ngraph
/// @brief Framework constructor of a tuple from a sequence of values.
std
::
shared_ptr
<
ngraph
::
runtime
::
Tuple
>
make_tuple
(
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Value
>>&
elements
);
/// @brief Same as numpy.allclose
/// @param a First tensor to compare
/// @param b Second tensor to compare
/// @param rtol Relative tolerance
/// @param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template
<
typename
ET
>
bool
all_close
(
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ET
>>&
a
,
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ET
>>&
b
,
typename
ET
::
type
rtol
=
1e-5
f
,
typename
ET
::
type
atol
=
1e-8
f
);
extern
template
bool
ngraph
::
runtime
::
all_close
<
ngraph
::
element
::
Float32
>
(
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ngraph
::
element
::
Float32
>>&
a
,
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ngraph
::
element
::
Float32
>>&
b
,
ngraph
::
element
::
Float32
::
type
rtol
,
ngraph
::
element
::
Float32
::
type
atol
);
extern
template
bool
ngraph
::
runtime
::
all_close
<
ngraph
::
element
::
Float64
>
(
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ngraph
::
element
::
Float64
>>&
a
,
const
std
::
shared_ptr
<
ngraph
::
runtime
::
ParameterizedTensorView
<
ngraph
::
element
::
Float64
>>&
b
,
ngraph
::
element
::
Float64
::
type
rtol
,
ngraph
::
element
::
Float64
::
type
atol
);
template
<
typename
T
>
bool
all_close
(
const
std
::
vector
<
T
>&
a
,
const
std
::
vector
<
T
>&
b
,
T
rtol
=
1e-5
f
,
T
atol
=
1e-8
f
);
extern
template
bool
ngraph
::
runtime
::
all_close
<
float
>
(
const
std
::
vector
<
float
>&
a
,
const
std
::
vector
<
float
>&
b
,
float
rtol
,
float
atol
);
extern
template
bool
ngraph
::
runtime
::
all_close
<
double
>
(
const
std
::
vector
<
double
>&
a
,
const
std
::
vector
<
double
>&
b
,
double
rtol
,
double
atol
);
}
}
src/ngraph/types/element_type.hpp
View file @
63fdc66e
...
...
@@ -153,6 +153,9 @@ namespace ngraph
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
float
)
using
Float32
=
TraitedType
<
float
>
;
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
double
)
using
Float64
=
TraitedType
<
double
>
;
NGRAPH_DEFINE_TRAITED_TYPE_NAME
(
int8_t
)
using
Int8
=
TraitedType
<
int8_t
>
;
...
...
test/util.cpp
View file @
63fdc66e
...
...
@@ -18,6 +18,8 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/utils.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
...
...
@@ -169,3 +171,25 @@ TEST(util, reduce)
EXPECT_EQ
(
actual
,
720
);
}
}
TEST
(
util
,
all_close
)
{
auto
manager
=
runtime
::
Manager
::
get
(
"NGVM"
);
auto
backend
=
manager
->
allocate_backend
();
// Create some tensors for input/output
auto
a
=
backend
->
make_parameterized_tensor_view
<
element
::
Float32
>
(
runtime
::
NDArray
<
float
,
2
>
({{
1
,
2
,
3
},
{
3
,
4
,
5
}}));
auto
b
=
backend
->
make_parameterized_tensor_view
<
element
::
Float32
>
(
runtime
::
NDArray
<
float
,
2
>
({{
1
,
2
,
3
},
{
3
,
4
,
5
}}));
EXPECT_TRUE
(
ngraph
::
runtime
::
all_close
(
a
,
b
));
auto
c
=
backend
->
make_parameterized_tensor_view
<
element
::
Float32
>
(
runtime
::
NDArray
<
float
,
2
>
({{
1.1
f
,
2
,
3
},
{
3
,
4
,
5
}}));
EXPECT_FALSE
(
ngraph
::
runtime
::
all_close
(
c
,
a
,
0
,
.05
f
));
EXPECT_TRUE
(
ngraph
::
runtime
::
all_close
(
c
,
a
,
0
,
.11
f
));
EXPECT_FALSE
(
ngraph
::
runtime
::
all_close
(
c
,
a
,
.05
f
,
0
));
EXPECT_TRUE
(
ngraph
::
runtime
::
all_close
(
c
,
a
,
.11
f
,
0
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment