Commit 4e2715b5 authored by Michał Karzyński's avatar Michał Karzyński Committed by Robert Kimball

[Py] Use get_friendly_name for object names (#968)

* [Py] Use get_friendly_name for object names
parent 8508410f
...@@ -18,24 +18,23 @@ ...@@ -18,24 +18,23 @@
import numpy as np import numpy as np
import ngraph as ng import ngraph as ng
shape = [2, 2] A = ng.parameter(shape=[2, 2], name='A', dtype=np.float32)
A = ng.parameter(shape, name='A') B = ng.parameter(shape=[2, 2], name='B')
B = ng.parameter(shape, name='B') C = ng.parameter(shape=[2, 2], name='C')
C = ng.parameter(shape, name='C')
# >>> print(A) # >>> print(A)
# <Parameter: 'A' (2, 2, float)> # <Parameter: 'A' ([2, 2], float)>
model = (A + B) * C model = (A + B) * C
# >>> print(model) # >>> print(model)
# <Node: 'Multiply_6'> # <Multiply: 'Multiply_14' ([2, 2])>
runtime = ng.runtime(backend_name='INTERPRETER') runtime = ng.runtime(backend_name='CPU')
# >>> print(runtime) # >>> print(runtime)
# <Runtime: Manager='INTERPRETER'> # <Runtime: Backend='CPU'>
computation = runtime.computation(model, A, B, C) computation = runtime.computation(model, A, B, C)
# >>> print(computation) # >>> print(computation)
# <Computation: Multiply_6(A, B, C)> # <Computation: Multiply_14(A, B, C)>
value_a = np.array([[1, 2], [3, 4]], dtype=np.float32) value_a = np.array([[1, 2], [3, 4]], dtype=np.float32)
value_b = np.array([[5, 6], [7, 8]], dtype=np.float32) value_b = np.array([[5, 6], [7, 8]], dtype=np.float32)
......
...@@ -63,8 +63,9 @@ def as_elementwise_compatible_nodes(*input_values): # type: (*NodeInput) -> Lis ...@@ -63,8 +63,9 @@ def as_elementwise_compatible_nodes(*input_values): # type: (*NodeInput) -> Lis
if len(shapes) > 1: if len(shapes) > 1:
log.warning('More than one different shape in input nodes %s.', input_nodes) log.warning('More than one different shape in input nodes %s.', input_nodes)
types = {node.get_element_type() for node in input_nodes} types = [node.get_element_type() for node in input_nodes]
if len(types) > 1: unique_types = {repr(type) for type in types}
if len(unique_types) > 1:
log.warning('More than one different data type in input nodes %s.', input_nodes) log.warning('More than one different data type in input nodes %s.', input_nodes)
sorted_shapes = sorted(shapes, key=len) sorted_shapes = sorted(shapes, key=len)
......
...@@ -69,6 +69,6 @@ void regclass_pyngraph_Node(py::module m) ...@@ -69,6 +69,6 @@ void regclass_pyngraph_Node(py::module m)
node.def("get_shape", &ngraph::Node::get_shape); node.def("get_shape", &ngraph::Node::get_shape);
node.def("get_argument", &ngraph::Node::get_argument); node.def("get_argument", &ngraph::Node::get_argument);
node.def_property("name", &ngraph::Node::get_name, &ngraph::Node::set_name); node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_name);
node.def_property_readonly("shape", &ngraph::Node::get_shape); node.def_property_readonly("shape", &ngraph::Node::get_shape);
} }
...@@ -33,7 +33,8 @@ void regclass_pyngraph_op_Parameter(py::module m) ...@@ -33,7 +33,8 @@ void regclass_pyngraph_op_Parameter(py::module m)
std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>(); std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>();
std::string shape = py::cast(self.get_shape()).attr("__str__")().cast<std::string>(); std::string shape = py::cast(self.get_shape()).attr("__str__")().cast<std::string>();
std::string type = self.get_element_type().c_type_string(); std::string type = self.get_element_type().c_type_string();
return "<" + class_name + ": '" + self.get_name() + "' (" + shape + ", " + type + ")>"; return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ", " + type +
")>";
}); });
parameter.def(py::init<const ngraph::element::Type&, const ngraph::Shape&>()); parameter.def(py::init<const ngraph::element::Type&, const ngraph::Shape&>());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment