constant.cpp 5.63 KB
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <pybind11/buffer_info.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <stdexcept>
#include <vector>

#include "ngraph/op/constant.hpp"
#include "ngraph/shape.hpp"
#include "pyngraph/ops/constant.hpp"

namespace py = pybind11;

template <typename T>
std::vector<ssize_t> _get_byte_strides(const ngraph::Shape& s)
{
    std::vector<ssize_t> byte_strides;
    std::vector<size_t> element_strides = ngraph::row_major_strides(s);
    for (auto v : element_strides)
    {
        byte_strides.push_back(static_cast<ssize_t>(v) * sizeof(T));
    }
    return byte_strides;
}

template <typename T>
py::buffer_info _get_buffer_info(const ngraph::op::Constant& c)
{
    ngraph::Shape shape = c.get_shape();
    return py::buffer_info(
        const_cast<void*>(c.get_data_ptr()),               /* Pointer to buffer */
        static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
        py::format_descriptor<T>::format(),                /* Python struct-style format
                                                              descriptor */
        static_cast<ssize_t>(shape.size()),                /* Number of dimensions */
        std::vector<ssize_t>{shape.begin(), shape.end()},  /* Buffer dimensions */
        _get_byte_strides<T>(shape)                        /* Strides (in bytes) for each index */
        );
}

void regclass_pyngraph_op_Constant(py::module m)
{
    py::class_<ngraph::op::Constant, std::shared_ptr<ngraph::op::Constant>, ngraph::Node> constant(
        m, "Constant", py::buffer_protocol());
    constant.doc() = "ngraph.impl.op.Constant wraps ngraph::op::Constant";
    constant.def(
        py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<char>&>());
    constant.def(
        py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<float>&>());
    constant.def(
        py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<double>&>());
    constant.def(
        py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<int8_t>&>());
    constant.def(py::init<const ngraph::element::Type&,
                          const ngraph::Shape&,
                          const std::vector<int16_t>&>());
    constant.def(py::init<const ngraph::element::Type&,
                          const ngraph::Shape&,
                          const std::vector<int32_t>&>());
    constant.def(py::init<const ngraph::element::Type&,
                          const ngraph::Shape&,
                          const std::vector<int64_t>&>());
    constant.def(py::init<const ngraph::element::Type&,
                          const ngraph::Shape&,
                          const std::vector<uint8_t>&>());
    constant.def(py::init<const ngraph::element::Type&,
                          const ngraph::Shape&,
                          const std::vector<uint16_t>&>());
    constant.def(py::init<const ngraph::element::Type&,
                          const ngraph::Shape&,
                          const std::vector<uint32_t>&>());
    constant.def(py::init<const ngraph::element::Type&,
                          const ngraph::Shape&,
                          const std::vector<uint64_t>&>());

    constant.def("get_value_strings", &ngraph::op::Constant::get_value_strings);
    // Provide buffer access
    constant.def_buffer([](const ngraph::op::Constant& self) -> py::buffer_info {
        auto element_type = self.get_element_type();
        if (element_type == ngraph::element::boolean)
        {
            return _get_buffer_info<char>(self);
        }
        else if (element_type == ngraph::element::f32)
        {
            return _get_buffer_info<float>(self);
        }
        else if (element_type == ngraph::element::f64)
        {
            return _get_buffer_info<double>(self);
        }
        else if (element_type == ngraph::element::i8)
        {
            return _get_buffer_info<int8_t>(self);
        }
        else if (element_type == ngraph::element::i16)
        {
            return _get_buffer_info<int16_t>(self);
        }
        else if (element_type == ngraph::element::i32)
        {
            return _get_buffer_info<int32_t>(self);
        }
        else if (element_type == ngraph::element::i64)
        {
            return _get_buffer_info<int64_t>(self);
        }
        else if (element_type == ngraph::element::u8)
        {
            return _get_buffer_info<uint8_t>(self);
        }
        else if (element_type == ngraph::element::u16)
        {
            return _get_buffer_info<uint16_t>(self);
        }
        else if (element_type == ngraph::element::u32)
        {
            return _get_buffer_info<uint32_t>(self);
        }
        else if (element_type == ngraph::element::u64)
        {
            return _get_buffer_info<uint64_t>(self);
        }
        else
        {
            throw std::runtime_error("Unsupproted data type!");
        }
    });
}