Commit db79474f authored by Scott Cyphers's avatar Scott Cyphers

Allow ndarray in constructor

parent 1b3f0e07
......@@ -17,6 +17,7 @@
#include <memory>
#include "ngraph/common.hpp"
#include "ngraph/runtime/ndarray.hpp"
namespace ngraph
{
......@@ -62,6 +63,17 @@ namespace ngraph
make_primary_tensor_view(ET::element_type(), shape));
}
template <typename ET>
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>
make_parameterized_tensor_view(const NDArrayBase<typename ET::type>& ndarray)
{
auto result =
std::dynamic_pointer_cast<ngraph::runtime::ParameterizedTensorView<ET>>(
make_primary_tensor_view(ET::element_type(), ndarray.get_shape()));
*result = ndarray;
return result;
}
/// @brief Construct a tuple handle from a sequence of values.
virtual std::shared_ptr<ngraph::runtime::Tuple>
make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements);
......
......@@ -53,6 +53,12 @@ namespace ngraph
ParameterizedTensorView(
const std::shared_ptr<ngraph::descriptor::TensorView>& descriptor);
ParameterizedTensorView(const NDArrayBase<typename ET::type>& initializer)
: ParameterizedTensorView(initializer.get_shape())
{
m_vector = initializer.get_vector();
}
using element_type = ET;
using value_type = typename ET::type;
using storage_type = std::vector<value_type>;
......@@ -64,8 +70,8 @@ namespace ngraph
return *this;
}
template <typename T, size_t N>
ParameterizedTensorView<ET>& operator=(const NDArray<T, N>& ndarray)
template <typename T>
ParameterizedTensorView<ET>& operator=(const NDArrayBase<T>& ndarray)
{
assert(ndarray.get_shape() == get_shape());
std::copy(ndarray.begin(), ndarray.end(), m_vector.begin());
......
......@@ -36,12 +36,12 @@ TEST(execute, abc)
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_parameterized_tensor_view<element::Float32>(shape);
*a = runtime::NDArray<float, 2>({{1, 2}, {3, 4}});
auto b = backend->make_parameterized_tensor_view<element::Float32>(shape);
*b = runtime::NDArray<float, 2>({{5, 6}, {7, 8}});
auto c = backend->make_parameterized_tensor_view<element::Float32>(shape);
*c = runtime::NDArray<float, 2>({{9, 10}, {11, 12}});
auto a = backend->make_parameterized_tensor_view<element::Float32>(
runtime::NDArray<float, 2>({{1, 2}, {3, 4}}));
auto b = backend->make_parameterized_tensor_view<element::Float32>(
runtime::NDArray<float, 2>({{5, 6}, {7, 8}}));
auto c = backend->make_parameterized_tensor_view<element::Float32>(
runtime::NDArray<float, 2>({{9, 10}, {11, 12}}));
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
(*cf)({a, b, c}, {result});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment