Commit 870ab827 authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[Py] Add __repr__ to Strides and CoordDiff (#1291)

* [Py] Add __repr__ to Strides and CoordDiff

* Apply clang-format

* Repr fix

* Apply clang-format
parent 5c56923a
......@@ -13,10 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <iterator>
#include <sstream>
#include <string>
#include "ngraph/coordinate_diff.hpp" //ngraph::CoordinateDiff
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/coordinate_diff.hpp" //ngraph::CoordinateDiff
#include "pyngraph/coordinate_diff.hpp"
namespace py = pybind11;
......@@ -29,4 +32,17 @@ void regclass_pyngraph_CoordinateDiff(py::module m)
coordinate_diff.def(py::init<const std::initializer_list<ptrdiff_t>&>());
coordinate_diff.def(py::init<const std::vector<ptrdiff_t>&>());
coordinate_diff.def(py::init<const ngraph::CoordinateDiff&>());
coordinate_diff.def("__str__", [](const ngraph::CoordinateDiff& self) -> std::string {
std::stringstream stringstream;
std::copy(self.begin(), self.end(), std::ostream_iterator<int>(stringstream, ", "));
std::string string = stringstream.str();
return string.substr(0, string.size() - 2);
});
coordinate_diff.def("__repr__", [](const ngraph::CoordinateDiff& self) -> std::string {
std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>();
std::string shape_str = py::cast(self).attr("__str__")().cast<std::string>();
return "<" + class_name + ": (" + shape_str + ")>";
});
}
......@@ -13,10 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <iterator>
#include <sstream>
#include <string>
#include "ngraph/strides.hpp" //ngraph::Strides
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/strides.hpp" //ngraph::Strides
#include "pyngraph/strides.hpp"
namespace py = pybind11;
......@@ -28,4 +31,17 @@ void regclass_pyngraph_Strides(py::module m)
strides.def(py::init<const std::initializer_list<size_t>&>());
strides.def(py::init<const std::vector<size_t>&>());
strides.def(py::init<const ngraph::Strides&>());
strides.def("__str__", [](const ngraph::Strides& self) -> std::string {
std::stringstream stringstream;
std::copy(self.begin(), self.end(), std::ostream_iterator<int>(stringstream, ", "));
std::string string = stringstream.str();
return string.substr(0, string.size() - 2);
});
strides.def("__repr__", [](const ngraph::Strides& self) -> std::string {
std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>();
std::string shape_str = py::cast(self).attr("__str__")().cast<std::string>();
return "<" + class_name + ": (" + shape_str + ")>";
});
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment