Commit ec919553 authored by Scott Cyphers's avatar Scott Cyphers

Tensors with eigen implementations for addition, multiplication

Start of external function calling support
parent 8a6c08df
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
include_directories(SYSTEM ${EIGEN_INCLUDE_DIR})
set (SRC set (SRC
log.cpp log.cpp
ngraph/descriptor/input.cpp ngraph/descriptor/input.cpp
...@@ -24,6 +26,9 @@ set (SRC ...@@ -24,6 +26,9 @@ set (SRC
ngraph/pass/propagate_types.cpp ngraph/pass/propagate_types.cpp
ngraph/pass/topological_sort.cpp ngraph/pass/topological_sort.cpp
ngraph/pass/tree_pass.cpp ngraph/pass/tree_pass.cpp
ngraph/runtime/call_frame.cpp
ngraph/runtime/eigen/tensor_view.cpp
ngraph/shape.cpp
ngraph/visualize.cpp ngraph/visualize.cpp
ops/binary_elementwise_builtin.cpp ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp ops/broadcast.cpp
...@@ -88,3 +93,4 @@ install(DIRECTORY ...@@ -88,3 +93,4 @@ install(DIRECTORY
FILES_MATCHING PATTERN "*.hpp" FILES_MATCHING PATTERN "*.hpp"
) )
add_dependencies(ngraph eigen)
...@@ -41,4 +41,11 @@ namespace ngraph ...@@ -41,4 +41,11 @@ namespace ngraph
/// A set of axes, for example, reduction axes /// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>; using AxisSet = std::set<size_t>;
/// Shape for a tensor
using Shape = std::vector<size_t>;
/// Strides of a tensor
using Strides = std::vector<size_t>;
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace descriptor
{
// A buffer identfies a chunk of storage
// In descriptors, we are identifying what will be associated with actual memory
// during execution.
class Buffer
{
protected:
size_t size;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/function.hpp"
namespace ngraph
{
namespace descriptor
{
// Describes the frame that will be used when a function is executing
class CallFrame
{
protected:
Function m_function;
// Will be provided by the caller
std::vector<std::shared_ptr<TensorView>> m_inputs;
std::vector<std::shared_ptr<TensorView>> m_outputs;
// Will be provided by the call mechanism
// Expect there to be only one buffer
std::vector<std::shared_ptr<Buffer>> m_buffers;
};
}
}
...@@ -20,12 +20,21 @@ namespace ngraph ...@@ -20,12 +20,21 @@ namespace ngraph
{ {
namespace descriptor namespace descriptor
{ {
using Strides = std::vector<size_t>; // An interface for describing implementations of tensor views
// Kernel selection will need to pay attention to the layout
class TensorViewLayout class TensorViewLayout
{ {
public:
virtual ~TensorViewLayout() {}
};
// The standard strided layout
class DenseTensorViewLayout : public TensorViewLayout
{
protected: protected:
Strides m_strides; std::shared_ptr<Buffer> m_buffer;
Strides m_strides;
size_t m_offset;
}; };
} }
} }
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#pragma once #pragma once
#include "ngraph/common.hpp" #include "ngraph/common.hpp"
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/call_frame.hpp"
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
...@@ -42,5 +44,9 @@ ...@@ -42,5 +44,9 @@
#include "ngraph/ops/parameter.hpp" #include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/function.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <sstream> #include <sstream>
#include "ngraph/element_type.hpp" #include "ngraph/element_type.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -59,7 +60,10 @@ namespace ngraph ...@@ -59,7 +60,10 @@ namespace ngraph
return ss.str(); return ss.str();
} }
typename T::type get_value() const { return m_value; } type get_value() const
{
return m_value;
}
protected: protected:
typename T::type m_value; typename T::type m_value;
...@@ -72,5 +76,55 @@ namespace ngraph ...@@ -72,5 +76,55 @@ namespace ngraph
using UInt8ScalarConstant = ScalarConstant<element::UInt8>; using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstant = ScalarConstant<element::UInt32>; using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstant = ScalarConstant<element::UInt64>; using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
// Defines methods to all constant tensors
class TensorConstantBase : public Node
{
protected:
TensorConstantBase(const std::shared_ptr<TensorViewType>& type)
: Node({}, type)
{
}
virtual void propagate_types() override;
};
// Implement a constant tensor for each element type.
template <typename T>
class TensorConstant : public TensorConstantBase
{
public:
// The ngraph element type
using element_type = T;
// The C++ type that holds the element type
using type = typename T::type;
TensorConstant(const Shape& shape)
: TensorConstantBase(std::make_shared<TensorViewType>(T::element_type(), shape))
, m_value(std::make_shared<ngraph::runtime::eigen::PrimaryTensorView<T>>(shape))
{
}
virtual std::string description() const override { return "TensorConstant"; }
virtual std::string get_node_id() const override
{
std::stringstream ss;
ss << description() << "_" /* << node_id() */;
return ss.str();
}
typename std::shared_ptr<ngraph::runtime::eigen::PrimaryTensorView<T>> get_value() const { return m_value; }
protected:
std::shared_ptr<ngraph::runtime::eigen::PrimaryTensorView<T>> m_value;
};
using Float32TensorConstant = TensorConstant<element::Float32>;
using Int8TensorConstant = TensorConstant<element::Int8>;
using Int32TensorConstant = TensorConstant<element::Int32>;
using Int64TensorConstant = TensorConstant<element::Int64>;
using UInt8TensorConstant = TensorConstant<element::UInt8>;
using UInt32TensorConstant = TensorConstant<element::UInt32>;
using UInt64TensorConstant = TensorConstant<element::UInt64>;
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace runtime;
CallFrame::CallFrame(Function& function,
const std::vector<std::shared_ptr<PrimaryTensorView>>& arguments,
const std::vector<std::shared_ptr<PrimaryTensorView>>& results)
{
m_tensors.insert(m_tensors.end(), arguments.begin(), arguments.end());
m_tensors.insert(m_tensors.end(), results.begin(), results.end());
// TBD
// From Function allocate tensors for the temporaries
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/runtime/function.hpp"
namespace ngraph
{
namespace runtime
{
class CallFrameAccessor;
// This is constructed when a runtime function is called.
class CallFrame
{
friend class CallFrameAccessor;
public:
CallFrame(Function& function,
const std::vector<std::shared_ptr<PrimaryTensorView>>& arguments,
const std::vector<std::shared_ptr<PrimaryTensorView>>& results);
protected:
std::vector<std::shared_ptr<PrimaryTensorView>> m_tensors;
};
class CallFrameAccessor
{
public:
CallFrameAccessor(size_t index)
: m_index(index)
{
}
std::shared_ptr<PrimaryTensorView> operator()(CallFrame& call_frame)
{
return call_frame.m_tensors[m_index];
}
protected:
size_t m_index;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <Eigen/Dense>
#include "ngraph.hpp"
using namespace Eigen;
using namespace ngraph::runtime::eigen;
using namespace ngraph::element;
template void ngraph::runtime::eigen::add<Float32>(const PrimaryTensorView<Float32>& arg0,
const PrimaryTensorView<Float32>& arg1,
PrimaryTensorView<Float32>& out);
template void ngraph::runtime::eigen::multiply<Float32>(const PrimaryTensorView<Float32>& arg0,
const PrimaryTensorView<Float32>& arg1,
PrimaryTensorView<Float32>& out);
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <Eigen/Dense>
#include <vector>
#include "ngraph/shape.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class PrimaryTensorView : public ngraph::runtime::PrimaryTensorView
{
public:
// Standard definitions from vector
using value_type = typename ET::type;
using storage_type = std::vector<value_type>;
using size_type = typename storage_type::size_type;
using difference_type = typename storage_type::difference_type;
using reference = typename storage_type::reference;
using const_reference = typename storage_type::const_reference;
using pointer = typename storage_type::pointer;
using const_pointer = typename storage_type::const_pointer;
using iterator = typename storage_type::iterator;
using const_iterator = typename storage_type::const_iterator;
using reverse_iterator = typename storage_type::reverse_iterator;
using const_reverse_iterator = typename storage_type::const_reverse_iterator;
// Mapping vector to eigen
using eigen_type = Eigen::Array<value_type, Eigen::Dynamic, 1>;
using eigen_map = Eigen::Map<eigen_type>;
PrimaryTensorView(const ngraph::Shape& shape)
: m_shape(shape)
, m_size(ngraph::shape_size(shape))
, m_strides(ngraph::row_major_strides(m_shape))
, m_vector(m_size, 0)
, m_map(&m_vector[0], m_size, 1)
{
}
template <typename T>
PrimaryTensorView& operator=(const T& value)
{
m_vector = value;
return *this;
}
// For getting the data out
const storage_type& get_vector() { return m_vector; }
eigen_map& get_map() { return m_map; }
const eigen_map& get_map() const { return m_map; }
const Shape& get_shape() const { return m_shape; }
protected:
ngraph::Shape m_shape;
size_t m_size;
ngraph::Strides m_strides;
storage_type m_vector;
eigen_map m_map;
};
template <typename ET>
void add(const PrimaryTensorView<ET>& arg0,
const PrimaryTensorView<ET>& arg1,
PrimaryTensorView<ET>& out)
{
out.get_map() = arg0.get_map() + arg1.get_map();
}
template <typename ET>
void multiply(const PrimaryTensorView<ET>& arg0,
const PrimaryTensorView<ET>& arg1,
PrimaryTensorView<ET>& out)
{
out.get_map() = arg0.get_map() * arg1.get_map();
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
// A compiled graph function
class Function
{
public:
virtual ~Function() {}
// Invoke the function with a the given inputs and outputs
void operator()(std::vector<std::shared_ptr<PrimaryTensorView>> inputs,
std::vector<std::shared_ptr<PrimaryTensorView>> outputs);
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
// Actual tensor views are parameterized on element type
class PrimaryTensorView
{
public:
virtual ~PrimaryTensorView(){}
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <vector>
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
size_t ngraph::shape_size(const Shape& shape)
{
size_t size = 1;
for (auto d : shape)
{
size *= d;
}
return size;
}
Strides ngraph::row_major_strides(const Shape& shape)
{
Strides strides;
size_t s = 1;
for (auto d = shape.rbegin(); d != shape.rend(); d++)
{
strides.push_back(s);
s *= *d;
}
reverse(strides.begin(), strides.end());
return strides;
}
...@@ -16,31 +16,13 @@ ...@@ -16,31 +16,13 @@
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>
#include "common.hpp"
namespace ngraph namespace ngraph
{ {
/** /// Number of elements in spanned by a shape
** Holds the shape of a tensor view. size_t shape_size(const Shape& shape);
**/
class Shape
{
public:
/// @param sizes A sequence of sizes.
Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes)
{
}
Shape(const std::vector<size_t>& sizes)
: m_sizes(sizes)
{
}
/// Conversion to a vector of sizes.
operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected: /// Row-major strides for a shape
std::vector<size_t> m_sizes; Strides row_major_strides(const Shape& shape);
};
} }
...@@ -17,3 +17,6 @@ ...@@ -17,3 +17,6 @@
using namespace ngraph::op; using namespace ngraph::op;
void ScalarConstantBase::propagate_types() {} void ScalarConstantBase::propagate_types() {}
void TensorConstantBase::propagate_types() {}
...@@ -56,4 +56,4 @@ std::ostream& ngraph::element::operator<<(std::ostream& out, const ngraph::eleme ...@@ -56,4 +56,4 @@ std::ostream& ngraph::element::operator<<(std::ostream& out, const ngraph::eleme
{ {
// out << "ElementType(" << obj.c_type_string() << ")"; // out << "ElementType(" << obj.c_type_string() << ")";
return out; return out;
} }
\ No newline at end of file
...@@ -29,6 +29,8 @@ set (SRC ...@@ -29,6 +29,8 @@ set (SRC
main.cpp main.cpp
op.cpp op.cpp
pass_manager.cpp pass_manager.cpp
runtime.cpp
shape.cpp
tensor.cpp tensor.cpp
test_tools.cpp test_tools.cpp
topological_sort.cpp topological_sort.cpp
......
...@@ -98,40 +98,84 @@ TEST(build_graph, literal) ...@@ -98,40 +98,84 @@ TEST(build_graph, literal)
ASSERT_NE(*int32_0->get_value_type(), *float_scalar_type); ASSERT_NE(*int32_0->get_value_type(), *float_scalar_type);
} }
TEST(build_graph, tensor)
{
// float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0);
auto float0 = make_shared<op::Float32TensorConstant>(Shape{2, 3});
auto float_tensor_type =
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3});
ASSERT_EQ(*float0->get_value_type(), *float_tensor_type);
auto d = make_shared<op::Dot>(float0, float0);
ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0);
auto int32_0 = make_shared<op::Int32TensorConstant>(Shape{3, 5});
auto int32_tensor_type =
make_shared<TensorViewType>(element::Int32::element_type(), Shape{3, 5});
ASSERT_EQ(*int32_0->get_value_type(), *int32_tensor_type);
ASSERT_NE(*int32_0->get_value_type(), *float_tensor_type);
}
TEST(build_graph, set_value_type_checked) TEST(build_graph, set_value_type_checked)
{ {
auto untyped_param = make_shared<op::Parameter>(); auto untyped_param = make_shared<op::Parameter>();
try { try
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4})); {
} catch(...){ untyped_param->set_value_type_checked(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
}
catch (...)
{
FAIL() << "Setting value type for first time type failed."; FAIL() << "Setting value type for first time type failed.";
} }
try { try
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4})); {
} catch(...){ untyped_param->set_value_type_checked(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
}
catch (...)
{
FAIL() << "Setting value type to same type failed."; FAIL() << "Setting value type to same type failed.";
} }
try { try
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 5})); {
untyped_param->set_value_type_checked(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 5}));
FAIL() << "Setting value type to a different shape did not fail."; FAIL() << "Setting value type to a different shape did not fail.";
} catch(const ngraph_error& error){ }
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType")); EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
} catch(...){ }
catch (...)
{
FAIL() << "Setting value type to a different shape did not failed with incorrect error."; FAIL() << "Setting value type to a different shape did not failed with incorrect error.";
} }
try { try
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Int32::element_type(), Shape{4, 4})); {
untyped_param->set_value_type_checked(
make_shared<TensorViewType>(element::Int32::element_type(), Shape{4, 4}));
FAIL() << "Setting value type to a different element type did not fail."; FAIL() << "Setting value type to a different element type did not fail.";
} catch(const ngraph_error& error){ }
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType")); EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
} catch(...){
FAIL() << "Setting value type to a different element type did not failed with incorrect error.";
} }
catch (...)
{
FAIL() << "Setting value type to a different element type did not failed with incorrect "
"error.";
}
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 4}); auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 4});
try { try
param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4})); {
} catch(...){ param->set_value_type_checked(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
}
catch (...)
{
FAIL() << "Setting value type to same type failed."; FAIL() << "Setting value type to same type failed.";
} }
} }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::runtime::eigen;
TEST(runtime, test_add)
{
auto x = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*x = std::vector<float>{1, 2, 3, 4};
auto y = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*y = std::vector<float>{5, 6, 7, 8};
auto z = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
add(*x, *y, *z);
ASSERT_EQ((vector<float>{6, 8, 10, 12}), z->get_vector());
}
TEST(runtime, test_multiply)
{
auto x = make_shared<op::Float32TensorConstant>(Shape{2, 2});
*x->get_value() = std::vector<float>{1, 2, 3, 4};
auto y = make_shared<op::Float32TensorConstant>(Shape{2, 2});
*y->get_value() = std::vector<float>{5, 6, 7, 8};
auto z = make_shared<op::Float32TensorConstant>(Shape{2, 2});
multiply(*x->get_value(), *y->get_value(), *z->get_value());
ASSERT_EQ((vector<float>{5, 12, 21, 32}), z->get_value()->get_vector());
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::runtime::eigen;
TEST(shape, test_shape_size)
{
ASSERT_EQ(1, shape_size(Shape{}));
ASSERT_EQ(2 * 3 * 5, shape_size(Shape{2, 3, 5}));
}
TEST(shape, test_shape_strides)
{
ASSERT_EQ(Strides{}, row_major_strides(Shape{}));
ASSERT_EQ(Strides{1}, row_major_strides(Shape{3}));
ASSERT_EQ((Strides{7, 1}), row_major_strides(Shape{2, 7}));
ASSERT_EQ((Strides{84, 12, 1}), row_major_strides(Shape{5, 7, 12}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment