Commit 785c1ce7 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Robert Kimball

[Py] Enable retrieve data from constant node. (#1214)

* Enable retrieving data from Constant in python.

* Test on wide value range.
parent 0c721561
......@@ -23,6 +23,8 @@ Low level wrappers for the nGraph c++ api in ngraph::op.
import sys
import six
import numpy as np
# workaround to load the libngraph.so with RTLD_GLOBAL
if six.PY3:
import os
......@@ -47,6 +49,15 @@ from _pyngraph.op import Broadcast
from _pyngraph.op import Ceiling
from _pyngraph.op import Concat
from _pyngraph.op import Constant
""" Retrieve Constant inner data.
Internally uses PyBind11 Numpy's buffer protocol.
:return Numpy array containing internally stored constant data.
"""
Constant.get_data = lambda self: np.array(self, copy=True)
from _pyngraph.op import Convert
from _pyngraph.op import Convolution
from _pyngraph.op import ConvolutionBackpropData
......
......@@ -14,18 +14,49 @@
* limitations under the License.
*******************************************************************************/
#include <stdexcept>
#include <vector>
#include <pybind11/buffer_info.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
//#include <string>
#include "ngraph/op/constant.hpp"
#include "ngraph/shape.hpp"
#include "pyngraph/ops/constant.hpp"
namespace py = pybind11;
template <typename T>
std::vector<ssize_t> _get_byte_strides(const ngraph::Shape& s)
{
std::vector<ssize_t> byte_strides;
std::vector<size_t> element_strides = ngraph::row_major_strides(s);
for (auto v : element_strides)
{
byte_strides.push_back(static_cast<ssize_t>(v) * sizeof(T));
}
return byte_strides;
}
template <typename T>
py::buffer_info _get_buffer_info(const ngraph::op::Constant& c)
{
ngraph::Shape shape = c.get_shape();
return py::buffer_info(
const_cast<void*>(c.get_data_ptr()), /* Pointer to buffer */
static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
py::format_descriptor<T>::format(), /* Python struct-style format
descriptor */
static_cast<ssize_t>(shape.size()), /* Number of dimensions */
std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
_get_byte_strides<T>(shape) /* Strides (in bytes) for each index */
);
}
void regclass_pyngraph_op_Constant(py::module m)
{
py::class_<ngraph::op::Constant, std::shared_ptr<ngraph::op::Constant>, ngraph::Node> constant(
m, "Constant");
m, "Constant", py::buffer_protocol());
constant.doc() = "ngraph.impl.op.Constant wraps ngraph::op::Constant";
constant.def(
py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<char>&>());
......@@ -56,4 +87,58 @@ void regclass_pyngraph_op_Constant(py::module m)
constant.def(py::init<const ngraph::element::Type&,
const ngraph::Shape&,
const std::vector<uint64_t>&>());
constant.def("get_value_strings", &ngraph::op::Constant::get_value_strings);
// Provide buffer access
constant.def_buffer([](const ngraph::op::Constant& self) -> py::buffer_info {
auto element_type = self.get_element_type();
if (element_type == ngraph::element::boolean)
{
return _get_buffer_info<char>(self);
}
else if (element_type == ngraph::element::f32)
{
return _get_buffer_info<float>(self);
}
else if (element_type == ngraph::element::f64)
{
return _get_buffer_info<double>(self);
}
else if (element_type == ngraph::element::i8)
{
return _get_buffer_info<int8_t>(self);
}
else if (element_type == ngraph::element::i16)
{
return _get_buffer_info<int16_t>(self);
}
else if (element_type == ngraph::element::i32)
{
return _get_buffer_info<int32_t>(self);
}
else if (element_type == ngraph::element::i64)
{
return _get_buffer_info<int64_t>(self);
}
else if (element_type == ngraph::element::u8)
{
return _get_buffer_info<uint8_t>(self);
}
else if (element_type == ngraph::element::u16)
{
return _get_buffer_info<uint16_t>(self);
}
else if (element_type == ngraph::element::u32)
{
return _get_buffer_info<uint32_t>(self);
}
else if (element_type == ngraph::element::u64)
{
return _get_buffer_info<uint64_t>(self);
}
else
{
throw std::runtime_error("Unsupproted data type!");
}
});
}
......@@ -171,3 +171,56 @@ def test_bad_data_shape():
value_b = np.array([[5, 6], [7, 8]], dtype=np.float32)
with pytest.raises(UserInputError):
computation(value_a, value_b)
def test_constant_get_data_bool():
input_data = np.array([True, False, False, True])
node = ng.constant(input_data, dtype=np.bool)
retrieved_data = node.get_data()
assert np.allclose(input_data, retrieved_data)
@pytest.mark.parametrize('data_type', [
np.float32,
np.float64,
])
def test_constant_get_data_floating_point(data_type):
np.random.seed(133391)
input_data = np.random.randn(2, 3, 4).astype(data_type)
min_value = -1.e20
max_value = 1.e20
input_data = min_value + input_data * max_value * data_type(2)
node = ng.constant(input_data, dtype=data_type)
retrieved_data = node.get_data()
assert np.allclose(input_data, retrieved_data)
@pytest.mark.parametrize('data_type', [
np.int64,
np.int32,
np.int16,
np.int8,
])
def test_constant_get_data_signed_integer(data_type):
np.random.seed(133391)
input_data = np.random.randint(np.iinfo(data_type).min, np.iinfo(data_type).max,
[2, 3, 4]).astype(data_type)
node = ng.constant(input_data, dtype=data_type)
retrieved_data = node.get_data()
assert np.allclose(input_data, retrieved_data)
@pytest.mark.parametrize('data_type', [
np.uint64,
np.uint32,
np.uint16,
np.uint8,
])
def test_constant_get_data_unsigned_integer(data_type):
np.random.seed(133391)
input_data = np.random.randn(2, 3, 4).astype(data_type)
input_data = (np.iinfo(data_type).min + input_data * np.iinfo(data_type).max +
input_data * np.iinfo(data_type).max)
node = ng.constant(input_data, dtype=data_type)
retrieved_data = node.get_data()
assert np.allclose(input_data, retrieved_data)
......@@ -34,7 +34,7 @@ def get_runtime():
def run_op_node(input_data, op_fun, *args):
"""Run computation on node performing `op_fun`.
`op_fun` have to needs to accept a node as an argument.
`op_fun` has to accept a node as an argument.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
......@@ -62,7 +62,7 @@ def run_op_node(input_data, op_fun, *args):
def run_op_numeric_data(input_data, op_fun, *args):
"""Run computation on node performing `op_fun`.
`op_fun` have to accept a scalar or an array.
`op_fun` has to accept a scalar or an array.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment