Commit f7bfd75e authored by Bob Kimball's avatar Bob Kimball

merge master

parents bf022034 70f8c112
......@@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
include_directories(SYSTEM ${EIGEN_INCLUDE_DIR})
set (SRC
log.cpp
ngraph/descriptor/input.cpp
......@@ -28,6 +30,9 @@ set (SRC
ngraph/pass/propagate_types.cpp
ngraph/pass/topological_sort.cpp
ngraph/pass/tree_pass.cpp
ngraph/runtime/call_frame.cpp
ngraph/runtime/eigen/tensor_view.cpp
ngraph/shape.cpp
ngraph/visualize.cpp
ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp
......@@ -92,3 +97,4 @@ install(DIRECTORY
FILES_MATCHING PATTERN "*.hpp"
)
add_dependencies(ngraph eigen)
......@@ -41,4 +41,11 @@ namespace ngraph
/// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>;
/// Shape for a tensor
using Shape = std::vector<size_t>;
/// Strides of a tensor
using Strides = std::vector<size_t>;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace descriptor
{
// A buffer identfies a chunk of storage
// In descriptors, we are identifying what will be associated with actual memory
// during execution.
class Buffer
{
protected:
size_t size;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/function.hpp"
namespace ngraph
{
namespace descriptor
{
// Describes the frame that will be used when a function is executing
class CallFrame
{
protected:
Function m_function;
// Will be provided by the caller
std::vector<std::shared_ptr<TensorView>> m_inputs;
std::vector<std::shared_ptr<TensorView>> m_outputs;
// Will be provided by the call mechanism
// Expect there to be only one buffer
std::vector<std::shared_ptr<Buffer>> m_buffers;
};
}
}
......@@ -20,12 +20,21 @@ namespace ngraph
{
namespace descriptor
{
using Strides = std::vector<size_t>;
// An interface for describing implementations of tensor views
// Kernel selection will need to pay attention to the layout
class TensorViewLayout
{
public:
virtual ~TensorViewLayout() {}
};
// The standard strided layout
class DenseTensorViewLayout : public TensorViewLayout
{
protected:
Strides m_strides;
std::shared_ptr<Buffer> m_buffer;
Strides m_strides;
size_t m_offset;
};
}
}
......@@ -19,6 +19,8 @@
#pragma once
#include "ngraph/common.hpp"
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/call_frame.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
......@@ -42,5 +44,9 @@
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/function.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
......@@ -17,6 +17,7 @@
#include <sstream>
#include "ngraph/element_type.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
namespace ngraph
{
......@@ -59,7 +60,10 @@ namespace ngraph
return ss.str();
}
typename T::type get_value() const { return m_value; }
type get_value() const
{
return m_value;
}
protected:
typename T::type m_value;
......@@ -72,5 +76,55 @@ namespace ngraph
using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
// Defines methods to all constant tensors
class TensorConstantBase : public Node
{
protected:
TensorConstantBase(const std::shared_ptr<TensorViewType>& type)
: Node({}, type)
{
}
virtual void propagate_types() override;
};
// Implement a constant tensor for each element type.
template <typename T>
class TensorConstant : public TensorConstantBase
{
public:
// The ngraph element type
using element_type = T;
// The C++ type that holds the element type
using type = typename T::type;
TensorConstant(const Shape& shape)
: TensorConstantBase(std::make_shared<TensorViewType>(T::element_type(), shape))
, m_value(std::make_shared<ngraph::runtime::eigen::PrimaryTensorView<T>>(shape))
{
}
virtual std::string description() const override { return "TensorConstant"; }
virtual std::string get_node_id() const override
{
std::stringstream ss;
ss << description() << "_" /* << node_id() */;
return ss.str();
}
typename std::shared_ptr<ngraph::runtime::eigen::PrimaryTensorView<T>> get_value() const { return m_value; }
protected:
std::shared_ptr<ngraph::runtime::eigen::PrimaryTensorView<T>> m_value;
};
using Float32TensorConstant = TensorConstant<element::Float32>;
using Int8TensorConstant = TensorConstant<element::Int8>;
using Int32TensorConstant = TensorConstant<element::Int32>;
using Int64TensorConstant = TensorConstant<element::Int64>;
using UInt8TensorConstant = TensorConstant<element::UInt8>;
using UInt32TensorConstant = TensorConstant<element::UInt32>;
using UInt64TensorConstant = TensorConstant<element::UInt64>;
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace runtime;
CallFrame::CallFrame(Function& function,
const std::vector<std::shared_ptr<PrimaryTensorView>>& arguments,
const std::vector<std::shared_ptr<PrimaryTensorView>>& results)
{
m_tensors.insert(m_tensors.end(), arguments.begin(), arguments.end());
m_tensors.insert(m_tensors.end(), results.begin(), results.end());
// TBD
// From Function allocate tensors for the temporaries
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/runtime/function.hpp"
namespace ngraph
{
namespace runtime
{
class CallFrameAccessor;
// This is constructed when a runtime function is called.
class CallFrame
{
friend class CallFrameAccessor;
public:
CallFrame(Function& function,
const std::vector<std::shared_ptr<PrimaryTensorView>>& arguments,
const std::vector<std::shared_ptr<PrimaryTensorView>>& results);
protected:
std::vector<std::shared_ptr<PrimaryTensorView>> m_tensors;
};
class CallFrameAccessor
{
public:
CallFrameAccessor(size_t index)
: m_index(index)
{
}
std::shared_ptr<PrimaryTensorView> operator()(CallFrame& call_frame)
{
return call_frame.m_tensors[m_index];
}
protected:
size_t m_index;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <Eigen/Dense>
#include "ngraph.hpp"
using namespace Eigen;
using namespace ngraph::runtime::eigen;
using namespace ngraph::element;
template void ngraph::runtime::eigen::add<Float32>(const PrimaryTensorView<Float32>& arg0,
const PrimaryTensorView<Float32>& arg1,
PrimaryTensorView<Float32>& out);
template void ngraph::runtime::eigen::multiply<Float32>(const PrimaryTensorView<Float32>& arg0,
const PrimaryTensorView<Float32>& arg1,
PrimaryTensorView<Float32>& out);
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <Eigen/Dense>
#include <vector>
#include "ngraph/shape.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class PrimaryTensorView : public ngraph::runtime::PrimaryTensorView
{
public:
// Standard definitions from vector
using value_type = typename ET::type;
using storage_type = std::vector<value_type>;
using size_type = typename storage_type::size_type;
using difference_type = typename storage_type::difference_type;
using reference = typename storage_type::reference;
using const_reference = typename storage_type::const_reference;
using pointer = typename storage_type::pointer;
using const_pointer = typename storage_type::const_pointer;
using iterator = typename storage_type::iterator;
using const_iterator = typename storage_type::const_iterator;
using reverse_iterator = typename storage_type::reverse_iterator;
using const_reverse_iterator = typename storage_type::const_reverse_iterator;
// Mapping vector to eigen
using eigen_type = Eigen::Array<value_type, Eigen::Dynamic, 1>;
using eigen_map = Eigen::Map<eigen_type>;
PrimaryTensorView(const ngraph::Shape& shape)
: m_shape(shape)
, m_size(ngraph::shape_size(shape))
, m_strides(ngraph::row_major_strides(m_shape))
, m_vector(m_size, 0)
, m_map(&m_vector[0], m_size, 1)
{
}
template <typename T>
PrimaryTensorView& operator=(const T& value)
{
m_vector = value;
return *this;
}
// For getting the data out
const storage_type& get_vector() { return m_vector; }
eigen_map& get_map() { return m_map; }
const eigen_map& get_map() const { return m_map; }
const Shape& get_shape() const { return m_shape; }
protected:
ngraph::Shape m_shape;
size_t m_size;
ngraph::Strides m_strides;
storage_type m_vector;
eigen_map m_map;
};
template <typename ET>
void add(const PrimaryTensorView<ET>& arg0,
const PrimaryTensorView<ET>& arg1,
PrimaryTensorView<ET>& out)
{
out.get_map() = arg0.get_map() + arg1.get_map();
}
template <typename ET>
void multiply(const PrimaryTensorView<ET>& arg0,
const PrimaryTensorView<ET>& arg1,
PrimaryTensorView<ET>& out)
{
out.get_map() = arg0.get_map() * arg1.get_map();
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
// A compiled graph function
class Function
{
public:
virtual ~Function() {}
// Invoke the function with a the given inputs and outputs
void operator()(std::vector<std::shared_ptr<PrimaryTensorView>> inputs,
std::vector<std::shared_ptr<PrimaryTensorView>> outputs);
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
// Actual tensor views are parameterized on element type
class PrimaryTensorView
{
public:
virtual ~PrimaryTensorView(){}
};
}
}
......@@ -12,11 +12,33 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "shape.hpp"
#include "util.hpp"
#include <algorithm>
#include <vector>
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Shape& obj)
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
size_t ngraph::shape_size(const Shape& shape)
{
size_t size = 1;
for (auto d : shape)
{
size *= d;
}
return size;
}
Strides ngraph::row_major_strides(const Shape& shape)
{
out << "{" << join(obj.m_sizes, ", ") << "}";
return out;
Strides strides;
size_t s = 1;
for (auto d = shape.rbegin(); d != shape.rend(); d++)
{
strides.push_back(s);
s *= *d;
}
reverse(strides.begin(), strides.end());
return strides;
}
......@@ -18,32 +18,13 @@
#include <iostream>
#include <vector>
#include "common.hpp"
namespace ngraph
{
/**
** Holds the shape of a tensor view.
**/
class Shape
{
public:
/// @param sizes A sequence of sizes.
Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes)
{
}
Shape(const std::vector<size_t>& sizes)
: m_sizes(sizes)
{
}
/// Conversion to a vector of sizes.
operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
friend std::ostream& operator<<(std::ostream&, const Shape&);
/// Number of elements in spanned by a shape
size_t shape_size(const Shape& shape);
protected:
std::vector<size_t> m_sizes;
};
/// Row-major strides for a shape
Strides row_major_strides(const Shape& shape);
}
......@@ -17,3 +17,6 @@
using namespace ngraph::op;
void ScalarConstantBase::propagate_types() {}
void TensorConstantBase::propagate_types() {}
......@@ -55,4 +55,4 @@ std::ostream& ngraph::element::operator<<(std::ostream& out, const ngraph::eleme
{
out << obj.m_cname;
return out;
}
\ No newline at end of file
}
......@@ -16,6 +16,7 @@
#include "ngraph/ngraph.hpp"
#include "log.hpp"
#include "util.hpp"
using namespace std;
using namespace ngraph;
......@@ -69,7 +70,7 @@ std::ostream& ngraph::operator<<(std::ostream& out, const ValueType& obj)
std::ostream& ngraph::operator<<(std::ostream& out, const TensorViewType& obj)
{
out << "TensorViewType(" << obj.m_element_type << ", " << obj.m_shape << ")";
out << "TensorViewType(" << obj.m_element_type << ", {" << join(obj.m_shape) << "})";
return out;
}
......
......@@ -31,6 +31,8 @@ set (SRC
pass_liveness.cpp
pass_manager.cpp
pass_memory_layout.cpp
runtime.cpp
shape.cpp
tensor.cpp
test_tools.cpp
topological_sort.cpp
......
......@@ -98,40 +98,84 @@ TEST(build_graph, literal)
ASSERT_NE(*int32_0->get_value_type(), *float_scalar_type);
}
TEST(build_graph, tensor)
{
// float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0);
auto float0 = make_shared<op::Float32TensorConstant>(Shape{2, 3});
auto float_tensor_type =
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3});
ASSERT_EQ(*float0->get_value_type(), *float_tensor_type);
auto d = make_shared<op::Dot>(float0, float0);
ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0);
auto int32_0 = make_shared<op::Int32TensorConstant>(Shape{3, 5});
auto int32_tensor_type =
make_shared<TensorViewType>(element::Int32::element_type(), Shape{3, 5});
ASSERT_EQ(*int32_0->get_value_type(), *int32_tensor_type);
ASSERT_NE(*int32_0->get_value_type(), *float_tensor_type);
}
TEST(build_graph, set_value_type_checked)
{
auto untyped_param = make_shared<op::Parameter>();
try {
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
} catch(...){
try
{
untyped_param->set_value_type_checked(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
}
catch (...)
{
FAIL() << "Setting value type for first time type failed.";
}
try {
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
} catch(...){
try
{
untyped_param->set_value_type_checked(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
}
catch (...)
{
FAIL() << "Setting value type to same type failed.";
}
try {
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 5}));
try
{
untyped_param->set_value_type_checked(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 5}));
FAIL() << "Setting value type to a different shape did not fail.";
} catch(const ngraph_error& error){
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
} catch(...){
}
catch (...)
{
FAIL() << "Setting value type to a different shape did not failed with incorrect error.";
}
try {
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Int32::element_type(), Shape{4, 4}));
try
{
untyped_param->set_value_type_checked(
make_shared<TensorViewType>(element::Int32::element_type(), Shape{4, 4}));
FAIL() << "Setting value type to a different element type did not fail.";
} catch(const ngraph_error& error){
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
} catch(...){
FAIL() << "Setting value type to a different element type did not failed with incorrect error.";
}
catch (...)
{
FAIL() << "Setting value type to a different element type did not failed with incorrect "
"error.";
}
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 4});
try {
param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
} catch(...){
try
{
param->set_value_type_checked(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
}
catch (...)
{
FAIL() << "Setting value type to same type failed.";
}
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::runtime::eigen;
TEST(runtime, test_add)
{
auto x = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*x = std::vector<float>{1, 2, 3, 4};
auto y = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*y = std::vector<float>{5, 6, 7, 8};
auto z = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
add(*x, *y, *z);
ASSERT_EQ((vector<float>{6, 8, 10, 12}), z->get_vector());
}
TEST(runtime, test_multiply)
{
auto x = make_shared<op::Float32TensorConstant>(Shape{2, 2});
*x->get_value() = std::vector<float>{1, 2, 3, 4};
auto y = make_shared<op::Float32TensorConstant>(Shape{2, 2});
*y->get_value() = std::vector<float>{5, 6, 7, 8};
auto z = make_shared<op::Float32TensorConstant>(Shape{2, 2});
multiply(*x->get_value(), *y->get_value(), *z->get_value());
ASSERT_EQ((vector<float>{5, 12, 21, 32}), z->get_value()->get_vector());
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::runtime::eigen;
TEST(shape, test_shape_size)
{
ASSERT_EQ(1, shape_size(Shape{}));
ASSERT_EQ(2 * 3 * 5, shape_size(Shape{2, 3, 5}));
}
TEST(shape, test_shape_strides)
{
ASSERT_EQ(Strides{}, row_major_strides(Shape{}));
ASSERT_EQ(Strides{1}, row_major_strides(Shape{3}));
ASSERT_EQ((Strides{7, 1}), row_major_strides(Shape{2, 7}));
ASSERT_EQ((Strides{84, 12, 1}), row_major_strides(Shape{5, 7, 12}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment