Commit 63fdc66e authored by Scott Cyphers's avatar Scott Cyphers

Add all_close for comparing tensors.

parent d79e0353
......@@ -44,3 +44,21 @@ size_t DenseTensorViewLayout::get_index_offset(const std::vector<size_t>& indice
}
return result;
}
bool DenseTensorViewLayout::operator==(const TensorViewLayout& other) const
{
const DenseTensorViewLayout* p_other = dynamic_cast<const DenseTensorViewLayout*>(&other);
if (nullptr == p_other)
return false;
if (get_element_type() != p_other->get_element_type())
return false;
if (m_strides != p_other->m_strides)
return false;
if (m_offset != p_other->m_offset)
return false;
return true;
}
......@@ -41,9 +41,11 @@ namespace ngraph
virtual size_t get_index_offset(const std::vector<size_t>& indices) override;
const Strides& get_strides() const { return m_strides; }
virtual bool operator==(const TensorViewLayout& other) const override;
protected:
Strides m_strides;
size_t m_offset;
size_t m_offset{0};
size_t m_size;
};
}
......
......@@ -59,6 +59,9 @@ namespace ngraph
/// Where this view is located in the buffer.
const BufferPos& get_buffer_pos() const { return m_buffer_pos; }
BufferPos& get_buffer_pos() { return m_buffer_pos; }
/// @brief Return true if this and other have the same element interpretation
virtual bool operator==(const TensorViewLayout& other) const = 0;
bool operator!=(const TensorViewLayout& other) const { return !(*this == other); }
protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type;
BufferPos m_buffer_pos;
......
......@@ -38,3 +38,9 @@ const ngraph::Shape& TensorView::get_shape() const
{
return m_descriptor->get_tensor_view_type()->get_shape();
}
std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout>
TensorView::get_tensor_view_layout() const
{
return m_descriptor->get_tensor_view_layout();
}
......@@ -60,6 +60,9 @@ namespace ngraph
const ngraph::Shape& get_shape() const;
std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout>
get_tensor_view_layout() const;
/// @brief Write bytes directly into the tensor
/// @param p Pointer to source of data
/// @param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
......
......@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <cmath>
#include "ngraph/runtime/utils.hpp"
std::shared_ptr<ngraph::runtime::Tuple> ngraph::runtime::make_tuple(
......@@ -19,3 +22,58 @@ std::shared_ptr<ngraph::runtime::Tuple> ngraph::runtime::make_tuple(
{
return std::make_shared<ngraph::runtime::Tuple>(elements);
}
template <typename ET>
bool ngraph::runtime::all_close(
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>& a,
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>& b,
typename ET::type rtol,
typename ET::type atol)
{
// Check that the layouts are compatible
if (*a->get_tensor_view_layout() != *b->get_tensor_view_layout())
{
throw ngraph_error("Cannot compare tensors with different layouts");
}
if (a->get_shape() != b->get_shape())
return false;
return ngraph::runtime::all_close(a->get_vector(), b->get_vector(), rtol, atol);
}
template bool ngraph::runtime::all_close<ngraph::element::Float32>(
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ngraph::element::Float32>>& a,
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ngraph::element::Float32>>& b,
ngraph::element::Float32::type rtol,
ngraph::element::Float32::type atol);
template bool ngraph::runtime::all_close<ngraph::element::Float64>(
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ngraph::element::Float64>>& a,
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ngraph::element::Float64>>& b,
ngraph::element::Float64::type rtol,
ngraph::element::Float64::type atol);
template <typename T>
bool ngraph::runtime::all_close(const std::vector<T>& a, const std::vector<T>& b, T rtol, T atol)
{
assert(a.size() == b.size());
for (size_t i = 0; i < a.size(); ++i)
{
if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i]))
{
return false;
}
}
return true;
}
template bool ngraph::runtime::all_close<float>(const std::vector<float>& a,
const std::vector<float>& b,
float rtol,
float atol);
template bool ngraph::runtime::all_close<double>(const std::vector<double>& a,
const std::vector<double>& b,
double rtol,
double atol);
......@@ -37,5 +37,49 @@ namespace ngraph
/// @brief Framework constructor of a tuple from a sequence of values.
std::shared_ptr<ngraph::runtime::Tuple>
make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements);
/// @brief Same as numpy.allclose
/// @param a First tensor to compare
/// @param b Second tensor to compare
/// @param rtol Relative tolerance
/// @param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename ET>
bool all_close(const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>& a,
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>& b,
typename ET::type rtol = 1e-5f,
typename ET::type atol = 1e-8f);
extern template bool ngraph::runtime::all_close<ngraph::element::Float32>(
const std::shared_ptr<
ngraph::runtime::ParameterizedTensorView<ngraph::element::Float32>>& a,
const std::shared_ptr<
ngraph::runtime::ParameterizedTensorView<ngraph::element::Float32>>& b,
ngraph::element::Float32::type rtol,
ngraph::element::Float32::type atol);
extern template bool ngraph::runtime::all_close<ngraph::element::Float64>(
const std::shared_ptr<
ngraph::runtime::ParameterizedTensorView<ngraph::element::Float64>>& a,
const std::shared_ptr<
ngraph::runtime::ParameterizedTensorView<ngraph::element::Float64>>& b,
ngraph::element::Float64::type rtol,
ngraph::element::Float64::type atol);
template <typename T>
bool all_close(const std::vector<T>& a,
const std::vector<T>& b,
T rtol = 1e-5f,
T atol = 1e-8f);
extern template bool ngraph::runtime::all_close<float>(const std::vector<float>& a,
const std::vector<float>& b,
float rtol,
float atol);
extern template bool ngraph::runtime::all_close<double>(const std::vector<double>& a,
const std::vector<double>& b,
double rtol,
double atol);
}
}
......@@ -153,6 +153,9 @@ namespace ngraph
NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
using Float32 = TraitedType<float>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(double)
using Float64 = TraitedType<double>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int8_t)
using Int8 = TraitedType<int8_t>;
......
......@@ -18,6 +18,8 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/utils.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -169,3 +171,25 @@ TEST(util, reduce)
EXPECT_EQ(actual, 720);
}
}
TEST(util, all_close)
{
auto manager = runtime::Manager::get("NGVM");
auto backend = manager->allocate_backend();
// Create some tensors for input/output
auto a = backend->make_parameterized_tensor_view<element::Float32>(
runtime::NDArray<float, 2>({{1, 2, 3}, {3, 4, 5}}));
auto b = backend->make_parameterized_tensor_view<element::Float32>(
runtime::NDArray<float, 2>({{1, 2, 3}, {3, 4, 5}}));
EXPECT_TRUE(ngraph::runtime::all_close(a, b));
auto c = backend->make_parameterized_tensor_view<element::Float32>(
runtime::NDArray<float, 2>({{1.1f, 2, 3}, {3, 4, 5}}));
EXPECT_FALSE(ngraph::runtime::all_close(c, a, 0, .05f));
EXPECT_TRUE(ngraph::runtime::all_close(c, a, 0, .11f));
EXPECT_FALSE(ngraph::runtime::all_close(c, a, .05f, 0));
EXPECT_TRUE(ngraph::runtime::all_close(c, a, .11f, 0));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment