Commit db79474f authored by Scott Cyphers's avatar Scott Cyphers

Allow ndarray in constructor

parent 1b3f0e07
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include "ngraph/common.hpp" #include "ngraph/common.hpp"
#include "ngraph/runtime/ndarray.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -62,6 +63,17 @@ namespace ngraph ...@@ -62,6 +63,17 @@ namespace ngraph
make_primary_tensor_view(ET::element_type(), shape)); make_primary_tensor_view(ET::element_type(), shape));
} }
template <typename ET>
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>
make_parameterized_tensor_view(const NDArrayBase<typename ET::type>& ndarray)
{
auto result =
std::dynamic_pointer_cast<ngraph::runtime::ParameterizedTensorView<ET>>(
make_primary_tensor_view(ET::element_type(), ndarray.get_shape()));
*result = ndarray;
return result;
}
/// @brief Construct a tuple handle from a sequence of values. /// @brief Construct a tuple handle from a sequence of values.
virtual std::shared_ptr<ngraph::runtime::Tuple> virtual std::shared_ptr<ngraph::runtime::Tuple>
make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements); make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements);
......
...@@ -53,6 +53,12 @@ namespace ngraph ...@@ -53,6 +53,12 @@ namespace ngraph
ParameterizedTensorView( ParameterizedTensorView(
const std::shared_ptr<ngraph::descriptor::TensorView>& descriptor); const std::shared_ptr<ngraph::descriptor::TensorView>& descriptor);
ParameterizedTensorView(const NDArrayBase<typename ET::type>& initializer)
: ParameterizedTensorView(initializer.get_shape())
{
m_vector = initializer.get_vector();
}
using element_type = ET; using element_type = ET;
using value_type = typename ET::type; using value_type = typename ET::type;
using storage_type = std::vector<value_type>; using storage_type = std::vector<value_type>;
...@@ -64,8 +70,8 @@ namespace ngraph ...@@ -64,8 +70,8 @@ namespace ngraph
return *this; return *this;
} }
template <typename T, size_t N> template <typename T>
ParameterizedTensorView<ET>& operator=(const NDArray<T, N>& ndarray) ParameterizedTensorView<ET>& operator=(const NDArrayBase<T>& ndarray)
{ {
assert(ndarray.get_shape() == get_shape()); assert(ndarray.get_shape() == get_shape());
std::copy(ndarray.begin(), ndarray.end(), m_vector.begin()); std::copy(ndarray.begin(), ndarray.end(), m_vector.begin());
......
...@@ -36,12 +36,12 @@ TEST(execute, abc) ...@@ -36,12 +36,12 @@ TEST(execute, abc)
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
// Create some tensors for input/output // Create some tensors for input/output
auto a = backend->make_parameterized_tensor_view<element::Float32>(shape); auto a = backend->make_parameterized_tensor_view<element::Float32>(
*a = runtime::NDArray<float, 2>({{1, 2}, {3, 4}}); runtime::NDArray<float, 2>({{1, 2}, {3, 4}}));
auto b = backend->make_parameterized_tensor_view<element::Float32>(shape); auto b = backend->make_parameterized_tensor_view<element::Float32>(
*b = runtime::NDArray<float, 2>({{5, 6}, {7, 8}}); runtime::NDArray<float, 2>({{5, 6}, {7, 8}}));
auto c = backend->make_parameterized_tensor_view<element::Float32>(shape); auto c = backend->make_parameterized_tensor_view<element::Float32>(
*c = runtime::NDArray<float, 2>({{9, 10}, {11, 12}}); runtime::NDArray<float, 2>({{9, 10}, {11, 12}}));
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape); auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
(*cf)({a, b, c}, {result}); (*cf)({a, b, c}, {result});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment