Commit 63fdc66e authored by Scott Cyphers

Add all_close for comparing tensors.

parent d79e0353
...@@ -44,3 +44,21 @@ size_t DenseTensorViewLayout::get_index_offset(const std::vector<size_t>& indice ...@@ -44,3 +44,21 @@ size_t DenseTensorViewLayout::get_index_offset(const std::vector<size_t>& indice
} }
return result; return result;
} }
// Two layouts are equal only when the other layout is also dense and both
// agree on element type, strides, and offset. A non-dense layout (failed
// dynamic_cast) is never equal to a dense one.
// NOTE(review): m_size is not part of the comparison - presumably it is
// derived from the strides/shape, but confirm upstream.
bool DenseTensorViewLayout::operator==(const TensorViewLayout& other) const
{
    auto p_other = dynamic_cast<const DenseTensorViewLayout*>(&other);
    return p_other != nullptr && get_element_type() == p_other->get_element_type() &&
           m_strides == p_other->m_strides && m_offset == p_other->m_offset;
}
...@@ -41,9 +41,11 @@ namespace ngraph ...@@ -41,9 +41,11 @@ namespace ngraph
virtual size_t get_index_offset(const std::vector<size_t>& indices) override; virtual size_t get_index_offset(const std::vector<size_t>& indices) override;
const Strides& get_strides() const { return m_strides; } const Strides& get_strides() const { return m_strides; }
virtual bool operator==(const TensorViewLayout& other) const override;
protected: protected:
Strides m_strides; Strides m_strides;
size_t m_offset; size_t m_offset{0};
size_t m_size; size_t m_size;
}; };
} }
......
...@@ -59,6 +59,9 @@ namespace ngraph ...@@ -59,6 +59,9 @@ namespace ngraph
/// Where this view is located in the buffer. /// Where this view is located in the buffer.
const BufferPos& get_buffer_pos() const { return m_buffer_pos; } const BufferPos& get_buffer_pos() const { return m_buffer_pos; }
BufferPos& get_buffer_pos() { return m_buffer_pos; } BufferPos& get_buffer_pos() { return m_buffer_pos; }
/// @brief Return true if this and other have the same element interpretation
virtual bool operator==(const TensorViewLayout& other) const = 0;
/// @brief Logical negation of operator==; defined once here for all layouts.
bool operator!=(const TensorViewLayout& other) const { return !(*this == other); }
protected: protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type; std::shared_ptr<const TensorViewType> m_tensor_view_type;
BufferPos m_buffer_pos; BufferPos m_buffer_pos;
......
...@@ -38,3 +38,9 @@ const ngraph::Shape& TensorView::get_shape() const ...@@ -38,3 +38,9 @@ const ngraph::Shape& TensorView::get_shape() const
{ {
return m_descriptor->get_tensor_view_type()->get_shape(); return m_descriptor->get_tensor_view_type()->get_shape();
} }
/// @brief Accessor for the layout descriptor held by this view's descriptor.
/// @return Shared pointer to the TensorViewLayout; may be null if no layout
///         has been assigned yet - TODO confirm against descriptor semantics.
std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout>
TensorView::get_tensor_view_layout() const
{
return m_descriptor->get_tensor_view_layout();
}
...@@ -60,6 +60,9 @@ namespace ngraph ...@@ -60,6 +60,9 @@ namespace ngraph
const ngraph::Shape& get_shape() const; const ngraph::Shape& get_shape() const;
/// @brief Get the layout descriptor associated with this tensor view.
std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout>
get_tensor_view_layout() const;
/// @brief Write bytes directly into the tensor /// @brief Write bytes directly into the tensor
/// @param p Pointer to source of data /// @param p Pointer to source of data
/// @param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned. /// @param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
......
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <cassert>
#include <cmath>
#include "ngraph/runtime/utils.hpp" #include "ngraph/runtime/utils.hpp"
std::shared_ptr<ngraph::runtime::Tuple> ngraph::runtime::make_tuple( std::shared_ptr<ngraph::runtime::Tuple> ngraph::runtime::make_tuple(
...@@ -19,3 +22,58 @@ std::shared_ptr<ngraph::runtime::Tuple> ngraph::runtime::make_tuple( ...@@ -19,3 +22,58 @@ std::shared_ptr<ngraph::runtime::Tuple> ngraph::runtime::make_tuple(
{ {
return std::make_shared<ngraph::runtime::Tuple>(elements); return std::make_shared<ngraph::runtime::Tuple>(elements);
} }
// Tensor-view flavour of all_close: first verifies the two views are laid out
// identically, then defers the element-wise comparison to the vector overload.
// @param a First tensor to compare
// @param b Second tensor to compare
// @param rtol Relative tolerance
// @param atol Absolute tolerance
// Throws ngraph_error when the layouts differ (a caller error, not "not
// close"); returns false on a shape mismatch; otherwise reports whether
// |a_i-b_i| <= atol + rtol*|b_i| holds for every element.
template <typename ET>
bool ngraph::runtime::all_close(
    const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>& a,
    const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>& b,
    typename ET::type rtol,
    typename ET::type atol)
{
    if (*a->get_tensor_view_layout() != *b->get_tensor_view_layout())
    {
        throw ngraph_error("Cannot compare tensors with different layouts");
    }
    return a->get_shape() == b->get_shape() &&
           ngraph::runtime::all_close(a->get_vector(), b->get_vector(), rtol, atol);
}
// Explicit instantiations of the tensor-view all_close for the element types
// declared extern in the header (Float32 and Float64).
template bool ngraph::runtime::all_close<ngraph::element::Float32>(
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ngraph::element::Float32>>& a,
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ngraph::element::Float32>>& b,
ngraph::element::Float32::type rtol,
ngraph::element::Float32::type atol);
template bool ngraph::runtime::all_close<ngraph::element::Float64>(
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ngraph::element::Float64>>& a,
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ngraph::element::Float64>>& b,
ngraph::element::Float64::type rtol,
ngraph::element::Float64::type atol);
namespace ngraph
{
    namespace runtime
    {
        /// @brief Element-wise allclose on plain vectors (same as numpy.allclose).
        /// @param a First vector to compare
        /// @param b Second vector to compare
        /// @param rtol Relative tolerance
        /// @param atol Absolute tolerance
        /// @return true iff the sizes match and, for every i,
        ///         |a[i]-b[i]| <= atol + rtol*|b[i]|.
        template <typename T>
        bool all_close(const std::vector<T>& a, const std::vector<T>& b, T rtol, T atol)
        {
            // Was only an assert(): under NDEBUG a size mismatch would read
            // past the end of b. Return false instead, matching both the
            // header's documented contract ("true if shapes match and ...")
            // and the tensor overload's shape check.
            if (a.size() != b.size())
            {
                return false;
            }
            for (std::size_t i = 0; i < a.size(); ++i)
            {
                if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i]))
                {
                    return false;
                }
            }
            return true;
        }
    }
}
// Explicit instantiations of the vector all_close for float and double;
// the matching extern template declarations live in the header.
template bool ngraph::runtime::all_close<float>(const std::vector<float>& a,
const std::vector<float>& b,
float rtol,
float atol);
template bool ngraph::runtime::all_close<double>(const std::vector<double>& a,
const std::vector<double>& b,
double rtol,
double atol);
...@@ -37,5 +37,49 @@ namespace ngraph ...@@ -37,5 +37,49 @@ namespace ngraph
/// @brief Framework constructor of a tuple from a sequence of values. /// @brief Framework constructor of a tuple from a sequence of values.
std::shared_ptr<ngraph::runtime::Tuple> std::shared_ptr<ngraph::runtime::Tuple>
make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements); make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements);
/// @brief Same as numpy.allclose
/// @param a First tensor to compare
/// @param b Second tensor to compare
/// @param rtol Relative tolerance
/// @param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
/// Throws ngraph_error if the two views have different layouts.
template <typename ET>
bool all_close(const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>& a,
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>& b,
typename ET::type rtol = 1e-5f,
typename ET::type atol = 1e-8f);
// Instantiated in the .cpp for the supported element types; extern template
// suppresses implicit instantiation in every including translation unit.
extern template bool ngraph::runtime::all_close<ngraph::element::Float32>(
const std::shared_ptr<
ngraph::runtime::ParameterizedTensorView<ngraph::element::Float32>>& a,
const std::shared_ptr<
ngraph::runtime::ParameterizedTensorView<ngraph::element::Float32>>& b,
ngraph::element::Float32::type rtol,
ngraph::element::Float32::type atol);
extern template bool ngraph::runtime::all_close<ngraph::element::Float64>(
const std::shared_ptr<
ngraph::runtime::ParameterizedTensorView<ngraph::element::Float64>>& a,
const std::shared_ptr<
ngraph::runtime::ParameterizedTensorView<ngraph::element::Float64>>& b,
ngraph::element::Float64::type rtol,
ngraph::element::Float64::type atol);
/// @brief Element-wise allclose on plain vectors.
/// @param a First vector to compare
/// @param b Second vector to compare
/// @param rtol Relative tolerance
/// @param atol Absolute tolerance
/// Returns true if for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
bool all_close(const std::vector<T>& a,
const std::vector<T>& b,
T rtol = 1e-5f,
T atol = 1e-8f);
// Instantiated in the .cpp for float and double.
extern template bool ngraph::runtime::all_close<float>(const std::vector<float>& a,
const std::vector<float>& b,
float rtol,
float atol);
extern template bool ngraph::runtime::all_close<double>(const std::vector<double>& a,
const std::vector<double>& b,
double rtol,
double atol);
} }
} }
...@@ -153,6 +153,9 @@ namespace ngraph ...@@ -153,6 +153,9 @@ namespace ngraph
NGRAPH_DEFINE_TRAITED_TYPE_NAME(float) NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
using Float32 = TraitedType<float>; using Float32 = TraitedType<float>;
// Double-precision floating-point element type, paralleling Float32 above.
NGRAPH_DEFINE_TRAITED_TYPE_NAME(double)
using Float64 = TraitedType<double>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int8_t) NGRAPH_DEFINE_TRAITED_TYPE_NAME(int8_t)
using Int8 = TraitedType<int8_t>; using Int8 = TraitedType<int8_t>;
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/utils.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -169,3 +171,25 @@ TEST(util, reduce) ...@@ -169,3 +171,25 @@ TEST(util, reduce)
EXPECT_EQ(actual, 720); EXPECT_EQ(actual, 720);
} }
} }
TEST(util, all_close)
{
    // Materialize tensors through the NGVM backend.
    auto manager = runtime::Manager::get("NGVM");
    auto backend = manager->allocate_backend();

    // Two identical 2x3 tensors compare close under the default tolerances.
    auto lhs = backend->make_parameterized_tensor_view<element::Float32>(
        runtime::NDArray<float, 2>({{1, 2, 3}, {3, 4, 5}}));
    auto rhs = backend->make_parameterized_tensor_view<element::Float32>(
        runtime::NDArray<float, 2>({{1, 2, 3}, {3, 4, 5}}));
    EXPECT_TRUE(ngraph::runtime::all_close(lhs, rhs));

    // Perturb one element by 0.1 and probe the absolute and relative
    // tolerances on either side of that difference.
    auto perturbed = backend->make_parameterized_tensor_view<element::Float32>(
        runtime::NDArray<float, 2>({{1.1f, 2, 3}, {3, 4, 5}}));
    EXPECT_FALSE(ngraph::runtime::all_close(perturbed, lhs, 0, .05f));
    EXPECT_TRUE(ngraph::runtime::all_close(perturbed, lhs, 0, .11f));
    EXPECT_FALSE(ngraph::runtime::all_close(perturbed, lhs, .05f, 0));
    EXPECT_TRUE(ngraph::runtime::all_close(perturbed, lhs, .11f, 0));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment