Commit 4e2715b5 authored by Michał Karzyński's avatar Michał Karzyński Committed by Robert Kimball

[Py] Use get_friendly_name for object names (#968)

* [Py] Use get_friendly_name for object names
parent 8508410f
......@@ -18,24 +18,23 @@
import numpy as np
import ngraph as ng
shape = [2, 2]
A = ng.parameter(shape, name='A')
B = ng.parameter(shape, name='B')
C = ng.parameter(shape, name='C')
A = ng.parameter(shape=[2, 2], name='A', dtype=np.float32)
B = ng.parameter(shape=[2, 2], name='B')
C = ng.parameter(shape=[2, 2], name='C')
# >>> print(A)
# <Parameter: 'A' (2, 2, float)>
# <Parameter: 'A' ([2, 2], float)>
model = (A + B) * C
# >>> print(model)
# <Node: 'Multiply_6'>
# <Multiply: 'Multiply_14' ([2, 2])>
runtime = ng.runtime(backend_name='INTERPRETER')
runtime = ng.runtime(backend_name='CPU')
# >>> print(runtime)
# <Runtime: Manager='INTERPRETER'>
# <Runtime: Backend='CPU'>
computation = runtime.computation(model, A, B, C)
# >>> print(computation)
# <Computation: Multiply_6(A, B, C)>
# <Computation: Multiply_14(A, B, C)>
value_a = np.array([[1, 2], [3, 4]], dtype=np.float32)
value_b = np.array([[5, 6], [7, 8]], dtype=np.float32)
......
......@@ -63,8 +63,9 @@ def as_elementwise_compatible_nodes(*input_values): # type: (*NodeInput) -> Lis
if len(shapes) > 1:
log.warning('More than one different shape in input nodes %s.', input_nodes)
types = {node.get_element_type() for node in input_nodes}
if len(types) > 1:
types = [node.get_element_type() for node in input_nodes]
unique_types = {repr(type) for type in types}
if len(unique_types) > 1:
log.warning('More than one different data type in input nodes %s.', input_nodes)
sorted_shapes = sorted(shapes, key=len)
......
......@@ -69,6 +69,6 @@ void regclass_pyngraph_Node(py::module m)
node.def("get_shape", &ngraph::Node::get_shape);
node.def("get_argument", &ngraph::Node::get_argument);
node.def_property("name", &ngraph::Node::get_name, &ngraph::Node::set_name);
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_name);
node.def_property_readonly("shape", &ngraph::Node::get_shape);
}
......@@ -33,7 +33,8 @@ void regclass_pyngraph_op_Parameter(py::module m)
std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>();
std::string shape = py::cast(self.get_shape()).attr("__str__")().cast<std::string>();
std::string type = self.get_element_type().c_type_string();
return "<" + class_name + ": '" + self.get_name() + "' (" + shape + ", " + type + ")>";
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ", " + type +
")>";
});
parameter.def(py::init<const ngraph::element::Type&, const ngraph::Shape&>());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment