Unverified Commit 602b06e2 authored by Fabian Boemer's avatar Fabian Boemer Committed by GitHub

Add set_config option to python backend (#4373)

* Add set_config option to python backend

* Address reviewer comment
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 54132cc9
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# ****************************************************************************** # ******************************************************************************
"""Provide a layer of abstraction for the ngraph++ runtime environment.""" """Provide a layer of abstraction for the ngraph++ runtime environment."""
import logging import logging
from typing import List, Union from typing import Dict, List, Union
import numpy as np import numpy as np
...@@ -42,6 +42,10 @@ class Runtime: ...@@ -42,6 +42,10 @@ class Runtime:
self.backend_name = backend_name self.backend_name = backend_name
self.backend = Backend.create(backend_name) self.backend = Backend.create(backend_name)
def set_config(self, config): # type: (Dict[str, str]) -> None
"""Set the backend configuration."""
self.backend.set_config(config, '')
def __repr__(self): # type: () -> str def __repr__(self): # type: () -> str
return "<Runtime: Backend='{}'>".format(self.backend_name) return "<Runtime: Backend='{}'>".format(self.backend_name)
......
...@@ -48,4 +48,5 @@ void regclass_pyngraph_runtime_Backend(py::module m) ...@@ -48,4 +48,5 @@ void regclass_pyngraph_runtime_Backend(py::module m)
const ngraph::element::Type&, const ngraph::Shape&)) & const ngraph::element::Type&, const ngraph::Shape&)) &
ngraph::runtime::Backend::create_tensor); ngraph::runtime::Backend::create_tensor);
backend.def("compile", &compile); backend.def("compile", &compile);
backend.def("set_config", &ngraph::runtime::Backend::set_config);
} }
...@@ -211,3 +211,9 @@ def test_constant_get_data_unsigned_integer(data_type): ...@@ -211,3 +211,9 @@ def test_constant_get_data_unsigned_integer(data_type):
node = ng.constant(input_data, dtype=data_type) node = ng.constant(input_data, dtype=data_type)
retrieved_data = node.get_data() retrieved_data = node.get_data()
assert np.allclose(input_data, retrieved_data) assert np.allclose(input_data, retrieved_data)
def test_backend_config():
dummy_config = {'dummy_option': 'dummy_value'}
# Expect no throw
ng.runtime(backend_name=test.BACKEND_NAME).set_config(dummy_config)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment