Commit b485bb33 authored by Adam Procter's avatar Adam Procter

De-Eigenize broadcast, and extend it to higher dimensions

parent c50164bc
......@@ -15,6 +15,7 @@ set (SRC
autodiff/adjoints.cpp
builder/autobroadcast.cpp
builder/reduce_ops.cpp
coordinate_iterator.cpp
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <cstdio>
#include <iostream>
#include <vector>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
#include "ngraph/except.hpp"
using namespace ngraph;
// Construct an iterator over a strided window inside an n-dimensional space.
//
// The window runs from `window_inner_corner` (inclusive) to
// `window_outer_corner` (exclusive), stepping by `strides` along each axis.
// All four vectors must have the same rank; rank mismatches are programmer
// errors (asserted), while a malformed window is reported via ngraph_error.
CoordinateIterator::CoordinateIterator(const Shape& space_shape,
                                       const Strides& strides,
                                       const Coordinate& window_outer_corner,
                                       const Coordinate& window_inner_corner)
    : m_space_shape(space_shape)
    , m_strides(strides)
    , m_window_outer_corner(window_outer_corner)
    , m_window_inner_corner(window_inner_corner)
    , m_current_coordinate(window_inner_corner)
{
    assert(space_shape.size() == window_inner_corner.size());
    assert(space_shape.size() == window_outer_corner.size());
    assert(space_shape.size() == strides.size());

    // Validate the window axis by axis. The order of the checks is kept
    // stable so that the same error is reported for a given bad input.
    for (size_t axis = 0; axis < space_shape.size(); axis++)
    {
        if (window_inner_corner[axis] > window_outer_corner[axis])
        {
            throw ngraph_error("Coordinate iterator inner corner is outside outer corner");
        }
        if (window_inner_corner[axis] >= m_space_shape[axis])
        {
            throw ngraph_error("Coordinate iterator inner corner is out of bounds");
        }
        if (window_outer_corner[axis] > m_space_shape[axis])
        {
            throw ngraph_error("Coordinate iterator outer corner is out of bounds");
        }
        if (m_strides[axis] == 0)
        {
            throw ngraph_error("Coordinate iterator stride is zero");
        }
    }
}
// Convenience constructor: iterate over every point of `space_shape`,
// i.e. a window covering the whole space with unit strides on every axis.
CoordinateIterator::CoordinateIterator(const Shape& space_shape)
    : CoordinateIterator(space_shape,
                         Strides(space_shape.size(), 1),
                         space_shape,
                         Coordinate(space_shape.size(), 0))
{
}
// Convenience constructor: iterate over the whole of `space_shape` with the
// caller-supplied per-axis strides (window corners default to the full space).
CoordinateIterator::CoordinateIterator(const Shape& space_shape, const Strides& strides)
    : CoordinateIterator(space_shape, strides, space_shape, Coordinate(space_shape.size(), 0))
{
}
// Row-major linearization of the current coordinate within the space.
// Uses Horner's scheme scanning from the outermost axis inward, which is
// equivalent to summing coordinate[i] * (product of all inner extents).
size_t CoordinateIterator::get_current_index() const
{
    size_t index = 0;
    for (size_t axis = 0; axis < m_space_shape.size(); axis++)
    {
        index = index * m_space_shape[axis] + m_current_coordinate[axis];
    }
    return index;
}
// Advance to the next coordinate in row-major order (innermost axis varies
// fastest). Returns true if a next coordinate exists; returns false — with
// the coordinate wrapped back to the inner corner — once the window is
// exhausted.
bool CoordinateIterator::increment()
{
    for (size_t axis = m_space_shape.size(); axis-- > 0;)
    {
        m_current_coordinate[axis] += m_strides[axis];
        if (m_current_coordinate[axis] < m_window_outer_corner[axis])
        {
            // This axis absorbed the step; we are done.
            return true;
        }
        // This axis overflowed: wrap it and carry into the next-outer axis.
        m_current_coordinate[axis] = m_window_inner_corner[axis];
    }
    // Every axis wrapped: the iteration is complete.
    return false;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cstdio>
#include <iostream>
#include <vector>
#include "ngraph/common.hpp"
namespace ngraph
{
/// An iterator over the coordinates of a (possibly strided) window within an
/// n-dimensional tensor space. Coordinates are visited in row-major order:
/// the iterator starts at the window's inner corner, and increment() advances
/// it, reporting whether a next coordinate exists.
class CoordinateIterator
{
public:
    /// Iterate the window [window_inner_corner, window_outer_corner) with the
    /// given per-axis strides. Throws ngraph_error if the window lies outside
    /// `space_shape`, the corners are inverted, or any stride is zero.
    CoordinateIterator(const Shape& space_shape,
                       const Strides& strides,
                       const Coordinate& window_outer_corner,
                       const Coordinate& window_inner_corner);
    /// Iterate the whole space with unit strides.
    CoordinateIterator(const Shape& space_shape);
    /// Iterate the whole space with the given per-axis strides.
    CoordinateIterator(const Shape& space_shape, const Strides& strides);

    /// The coordinate the iterator is currently positioned at.
    Coordinate get_current_coordinate() const { return m_current_coordinate; }
    /// The row-major linear index of the current coordinate within the space.
    size_t get_current_index() const;
    /// Advance to the next coordinate; returns false once the window is exhausted.
    bool increment();

private:
    const Shape m_space_shape;
    const Strides m_strides;
    const Coordinate m_window_outer_corner;
    const Coordinate m_window_inner_corner;
    Coordinate m_current_coordinate;
};
}
......@@ -44,6 +44,7 @@
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
// Broadcasts the dense row-major tensor `arg` (shape `in_shape`) into `out`
// (shape `out_shape`) by replicating it along each axis in `broadcast_axes`.
//
// `out_shape` must be `in_shape` with the broadcast axes inserted: deleting
// every axis in `broadcast_axes` from `out_shape` yields `in_shape`.
// `arg` is read-only; `out` must have room for every element of `out_shape`.
template <typename T>
void broadcast(const T* arg,
               T* out,
               const Shape& in_shape,
               const Shape& out_shape,
               const AxisSet& broadcast_axes)
{
    size_t out_rank = out_shape.size();

    // The output window always advances one element at a time, and the
    // strides do not depend on the input position, so build them once
    // instead of reconstructing them on every outer iteration.
    Strides out_strides(out_rank, 1);

    // Outer loop: walk over every element of the input.
    CoordinateIterator arg_iter(in_shape);
    do
    {
        // For this input element, build the output window to fill: it spans
        // the full extent of each new broadcast axis, and is pinned to the
        // current input coordinate on every pre-existing axis.
        Coordinate arg_coordinate = arg_iter.get_current_coordinate();
        Coordinate out_outer_corner(out_rank);
        Coordinate out_inner_corner(out_rank);
        size_t arg_pos = 0;
        for (size_t i = 0; i < out_rank; i++)
        {
            if (broadcast_axes.find(i) == broadcast_axes.end())
            {
                // Pre-existing axis: stay put at the input's coordinate
                // (an inner corner equal to the outer corner yields a
                // single-point extent along this axis).
                out_outer_corner[i] = arg_coordinate[arg_pos];
                out_inner_corner[i] = arg_coordinate[arg_pos];
                arg_pos++;
            }
            else
            {
                // New broadcast axis: cover the whole axis.
                out_outer_corner[i] = out_shape[i];
                out_inner_corner[i] = 0;
            }
        }

        // Inner loop: copy the input element to every position in the window.
        CoordinateIterator out_iter(out_shape, out_strides, out_outer_corner, out_inner_corner);
        do
        {
            out[out_iter.get_current_index()] = arg[arg_iter.get_current_index()];
        } while (out_iter.increment());
    } while (arg_iter.increment());
}
}
}
}
......@@ -96,6 +96,7 @@
#include "ngraph/runtime/ngvm/instruction/add.hpp"
#include "ngraph/runtime/ngvm/instruction/asin.hpp"
#include "ngraph/runtime/ngvm/instruction/atan.hpp"
#include "ngraph/runtime/ngvm/instruction/broadcast.hpp"
#include "ngraph/runtime/ngvm/instruction/call.hpp"
#include "ngraph/runtime/ngvm/instruction/ceiling.hpp"
#include "ngraph/runtime/ngvm/instruction/constant.hpp"
......@@ -420,15 +421,23 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>(
n->get_arguments().at(0)->get_value_type());
assert(nullptr != arg_tensor_type);
auto arg_shape = arg_tensor_type->get_shape();
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto arg_shape = arg_tensor_type->get_shape();
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
instruction::BroadcastInstruction,
in[0],
out[0],
arg_shape,
result_shape,
broadcast->get_broadcast_axes());
/*
if (broadcast->get_broadcast_axes().empty())
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
......@@ -473,7 +482,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
else
{
throw ngraph_error("Broadcast not implemented for rank>2 in VM yet");
}
}*/
};
REGISTER_TO_OP_MAP(op::Concat)
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/broadcast.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
/// VM instruction performing an n-dimensional broadcast: reads the input
/// tensor view, replicates it along the requested output axes via
/// kernel::broadcast, and writes the result to the output view.
template <typename ET>
class BroadcastInstruction : public Instruction
{
public:
    /// \param arg            Input tensor view.
    /// \param out            Output tensor view.
    /// \param arg_shape      Shape of the input tensor.
    /// \param out_shape      Shape of the output tensor.
    /// \param broadcast_axes Output axes along which the input is replicated.
    BroadcastInstruction(const TensorViewInfo& arg,
                         const TensorViewInfo& out,
                         const Shape& arg_shape,
                         const Shape& out_shape,
                         const AxisSet& broadcast_axes)
        : m_arg(arg)
        , m_out(out)
        , m_arg_shape(arg_shape)
        , m_out_shape(out_shape)
        , m_broadcast_axes(broadcast_axes)
    {
    }

    virtual void execute(CallFrame& call_frame) const override
    {
        using T = typename ET::type;
        T* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
        T* out = get_tensor_data_ptr<ET>(call_frame, m_out);
        kernel::broadcast<T>(arg, out, m_arg_shape, m_out_shape, m_broadcast_axes);
    }

protected:
    TensorViewInfo m_arg;
    TensorViewInfo m_out;
    Shape m_arg_shape;
    Shape m_out_shape;
    AxisSet m_broadcast_axes;
};
}
}
}
}
......@@ -22,10 +22,11 @@ include_directories(
)
set (SRC
autodiff.cpp
builder_autobroadcast.cpp
builder_reduce_ops.cpp
autodiff.cpp
build_graph.cpp
coordinate_iterator.cpp
copy.cpp
eigen.cpp
element_type.cpp
......
......@@ -1512,6 +1512,78 @@ TEST(${BACKEND_NAME}, broadcast_vector_rowwise_int64)
result->get_vector<element::Int64::type>());
}
// Broadcast a 2x2 matrix along a new leading axis (axis 0), producing a
// 2x2x2 tensor that holds two stacked copies of the matrix.
TEST(${BACKEND_NAME}, broadcast_matrix_0)
{
    auto shape_a = Shape{2, 2};
    auto shape_r = Shape{2, 2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
    auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
    auto f = make_shared<Function>(
        make_shared<op::Broadcast>(A, shape_r, AxisSet{0}), rt, op::Parameters{A});

    auto manager = runtime::Manager::get("${BACKEND_NAME}");
    auto compiled = manager->compile(f);
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(compiled);

    // Input and output tensor views.
    auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
    copy_data(a, vector<element::Float32::type>{1, 2, 3, 4});
    auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);

    cf->call({a}, {result});
    // Two copies of the matrix stacked along axis 0.
    ASSERT_EQ((vector<element::Float32::type>{1, 2, 3, 4, 1, 2, 3, 4}),
              result->get_vector<element::Float32::type>());
}
// Broadcast a 2x2 matrix along a new middle axis (axis 1), producing a
// 2x2x2 tensor in which each row of the input is duplicated.
TEST(${BACKEND_NAME}, broadcast_matrix_1)
{
    auto shape_a = Shape{2, 2};
    auto shape_r = Shape{2, 2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
    auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
    auto f = make_shared<Function>(
        make_shared<op::Broadcast>(A, shape_r, AxisSet{1}), rt, op::Parameters{A});

    auto manager = runtime::Manager::get("${BACKEND_NAME}");
    auto compiled = manager->compile(f);
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(compiled);

    // Input and output tensor views.
    auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
    copy_data(a, vector<element::Float32::type>{1, 2, 3, 4});
    auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);

    cf->call({a}, {result});
    // Each input row appears twice along the new axis.
    ASSERT_EQ((vector<element::Float32::type>{1, 2, 1, 2, 3, 4, 3, 4}),
              result->get_vector<element::Float32::type>());
}
// Broadcast a 2x2 matrix along a new trailing axis (axis 2), producing a
// 2x2x2 tensor in which each input element is duplicated.
TEST(${BACKEND_NAME}, broadcast_matrix_2)
{
    auto shape_a = Shape{2, 2};
    auto shape_r = Shape{2, 2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
    auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
    auto f = make_shared<Function>(
        make_shared<op::Broadcast>(A, shape_r, AxisSet{2}), rt, op::Parameters{A});

    auto manager = runtime::Manager::get("${BACKEND_NAME}");
    auto compiled = manager->compile(f);
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(compiled);

    // Input and output tensor views.
    auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
    copy_data(a, vector<element::Float32::type>{1, 2, 3, 4});
    auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);

    cf->call({a}, {result});
    // Each input element appears twice along the new innermost axis.
    ASSERT_EQ((vector<element::Float32::type>{1, 1, 2, 2, 3, 3, 4, 4}),
              result->get_vector<element::Float32::type>());
}
TEST(${BACKEND_NAME}, convert_int32_float32)
{
auto shape = Shape{2, 2};
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
// A full-rank window covering the whole space with unit strides constructs
// without throwing.
TEST(coordinate_iterator, construct)
{
    Shape space_shape{2, 3, 5, 6};
    Strides strides{1, 1, 1, 1};
    Coordinate window_outer_corner{2, 3, 5, 6};
    Coordinate window_inner_corner{0, 0, 0, 0};
    auto ci = CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
}
// The two-argument constructor (shape + strides) defaults the window to the
// whole space and constructs without throwing.
TEST(coordinate_iterator, construct_defaults)
{
    Shape space_shape{2, 3, 5, 6};
    Strides strides{2, 2, 2, 1};
    auto ci = CoordinateIterator(space_shape, strides);
}
// The one-argument constructor defaults both the strides (all ones) and the
// window (whole space) and constructs without throwing.
TEST(coordinate_iterator, construct_defaults_stride)
{
    Shape space_shape{2, 3, 5, 6};
    auto ci = CoordinateIterator(space_shape);
}
// An outer corner exceeding the space shape (4 > 3 on axis 1) must throw.
TEST(coordinate_iterator, construct_bad_outer_oob)
{
    Shape space_shape{2, 3, 5, 6};
    Strides strides{1, 1, 1, 1};
    Coordinate window_outer_corner{2, 4, 5, 6};
    Coordinate window_inner_corner{0, 0, 0, 0};
    EXPECT_ANY_THROW({
        auto ci =
            CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
    });
}
// An inner corner at or past the space extent (3 >= 3 on axis 1) must throw.
TEST(coordinate_iterator, construct_bad_inner_oob)
{
    Shape space_shape{2, 3, 5, 6};
    Strides strides{1, 1, 1, 1};
    Coordinate window_outer_corner{2, 3, 5, 6};
    Coordinate window_inner_corner{0, 3, 0, 0};
    EXPECT_ANY_THROW({
        auto ci =
            CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
    });
}
// An inner corner beyond the outer corner (2 > 1 on axis 1) must throw.
TEST(coordinate_iterator, construct_bad_inner_outside_outer)
{
    Shape space_shape{2, 3, 5, 6};
    Strides strides{1, 1, 1, 1};
    Coordinate window_outer_corner{2, 1, 5, 6};
    Coordinate window_inner_corner{0, 2, 0, 0};
    EXPECT_ANY_THROW({
        auto ci =
            CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
    });
}
// A zero stride on any axis (axis 1 here) must throw — it would never advance.
TEST(coordinate_iterator, construct_bad_zero_stride)
{
    Shape space_shape{2, 3, 5, 6};
    Strides strides{1, 0, 1, 1};
    Coordinate window_outer_corner{2, 3, 5, 6};
    Coordinate window_inner_corner{0, 0, 0, 0};
    EXPECT_ANY_THROW({
        auto ci =
            CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
    });
}
// With default strides the iterator visits every point of the space exactly
// once, and the linear index increases by 1 per step (row-major order).
TEST(coordinate_iterator, cover_count_defaults)
{
    Shape space_shape{2, 3, 5, 6};
    auto ci = CoordinateIterator(space_shape);
    size_t count = 0;
    size_t expected_index = 0;
    do
    {
        count++;
        EXPECT_EQ(ci.get_current_index(), expected_index);
        expected_index++;
    } while (ci.increment());
    // Total visits equal the volume of the space.
    EXPECT_EQ(count, 2 * 3 * 5 * 6);
}
// A stride of 2 on the innermost axis visits every other point, so the linear
// index advances by 2 per step and only half the space is covered.
TEST(coordinate_iterator, cover_count_stride_2)
{
    Shape space_shape{2, 3, 5, 6};
    Strides strides{1, 1, 1, 2};
    auto ci = CoordinateIterator(space_shape, strides);
    size_t count = 0;
    size_t expected_index = 0;
    do
    {
        count++;
        EXPECT_EQ(ci.get_current_index(), expected_index);
        expected_index += 2;
    } while (ci.increment());
    EXPECT_EQ(count, 2 * 3 * 5 * 6 / 2);
}
#define CEIL_DIV(x, y) (1 + (((x)-1) / (y)))
// With strides that do not divide the axis extents evenly, the number of
// points visited per axis is ceil(extent / stride).
TEST(coordinate_iterator, cover_count_stride_uneven)
{
    Shape space_shape{2, 3, 5, 6};
    Strides strides{1, 2, 2, 3};
    auto ci = CoordinateIterator(space_shape, strides);
    size_t count = 0;
    do
    {
        count++;
    } while (ci.increment());
    EXPECT_EQ(count, CEIL_DIV(2, 1) * CEIL_DIV(3, 2) * CEIL_DIV(5, 2) * CEIL_DIV(6, 3));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment