Commit 158de495 authored by Adam Procter, committed by Scott Cyphers

Brute-force broadcasting (#161)

* Implement broadcasting of scalars through VM

* Implement vector-to-matrix broadcasting through VM

* Add missing files

* Address review comments
parent c3113593
......@@ -39,6 +39,8 @@ namespace ngraph
virtual std::string description() const override { return "Broadcast"; }
virtual void propagate_types() override;
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
protected:
Shape m_shape;
AxisSet m_broadcast_axes;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class BroadcastScalarInstruction : public Instruction
{
public:
    // Broadcasts a rank-0 (scalar) input tensor into every element of the
    // output tensor, whatever the output's rank.
    BroadcastScalarInstruction(const TensorViewInfo& input,
                               const TensorViewInfo& output)
        : m_arg(input)
        , m_out(output)
    {
    }

    virtual void execute(CallFrame& call_frame) const override
    {
        // NOTE: the output is deliberately viewed as a flat 1-d array even
        // though it may have higher rank. fmt::V folds all trailing
        // dimensions into one (and fmt::M ignores them), so a flat view
        // covers every element. The scalar argument is read out as the
        // single element at position (0,0).
        EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg)(0,0);
    }

protected:
    TensorViewInfo m_arg;
    TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class BroadcastVectorColwiseInstruction : public Instruction
{
public:
    // Broadcasts a vector of length m across the columns of an m-by-n
    // output matrix, i.e. every column of the result equals the input.
    BroadcastVectorColwiseInstruction(const TensorViewInfo& input,
                                      const TensorViewInfo& output)
        : m_arg(input)
        , m_out(output)
    {
    }

    virtual void execute(CallFrame& call_frame) const override
    {
        // Eigen's colwise() assignment replicates the vector into each column.
        EigenMatrix<ET>(call_frame, m_out).colwise() =
            EigenVector<ET>(call_frame, m_arg);
    }

protected:
    TensorViewInfo m_arg;
    TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class BroadcastVectorRowwiseInstruction : public Instruction
{
public:
    // Broadcasts a vector of length n across the rows of an m-by-n
    // output matrix, i.e. every row of the result equals the input.
    BroadcastVectorRowwiseInstruction(const TensorViewInfo& input,
                                      const TensorViewInfo& output)
        : m_arg(input)
        , m_out(output)
    {
    }

    virtual void execute(CallFrame& call_frame) const override
    {
        // The argument maps as a column vector, so it must be transposed
        // to a row vector before rowwise() replication.
        EigenMatrix<ET>(call_frame, m_out).rowwise() =
            EigenVector<ET>(call_frame, m_arg).transpose();
    }

protected:
    TensorViewInfo m_arg;
    TensorViewInfo m_out;
};
}
}
}
......@@ -25,6 +25,7 @@
#include "ngraph/node.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/divide.hpp"
......@@ -48,6 +49,9 @@
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/eigen/abs.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/broadcast_scalar.hpp"
#include "ngraph/runtime/eigen/broadcast_vector_colwise.hpp"
#include "ngraph/runtime/eigen/broadcast_vector_rowwise.hpp"
#include "ngraph/runtime/eigen/call.hpp"
#include "ngraph/runtime/eigen/concat_matrix.hpp"
#include "ngraph/runtime/eigen/concat_vector.hpp"
......@@ -138,6 +142,59 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
dynamic_cast<const op::TensorConstant<element::Float32>*>(n)->get_value()->get_vector(),
out[0]);
REGISTER_TO_OP_MAP(op::Broadcast)
{
    auto broadcast = static_cast<const op::Broadcast*>(n);

    auto arg_tensor_type =
        dynamic_pointer_cast<const TensorViewType>(n->get_arguments().at(0)->get_value_type());
    assert(nullptr != arg_tensor_type);

    auto result_tensor_type =
        dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
    assert(nullptr != result_tensor_type);

    auto arg_shape = arg_tensor_type->get_shape();
    auto result_shape = result_tensor_type->get_shape();

    if (broadcast->get_broadcast_axes().empty())
    {
        // Degenerate case: no broadcast axes is just a copy.
        ef->get_instructions()->push_back(
            make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
                in[0].get_index(), out[0].get_index()));
    }
    else if (arg_shape.size() == 0)
    {
        // Scalar argument: fills the output (of any rank) with one value.
        ef->get_instructions()->push_back(
            make_shared<runtime::eigen::BroadcastScalarInstruction<element::Float32>>(
                in[0], out[0]));
    }
    else if (arg_shape.size() == 1 && result_shape.size() == 2)
    {
        // Vector-to-matrix broadcast: the single broadcast axis determines
        // whether the vector is replicated across columns or across rows.
        if (broadcast->get_broadcast_axes() == AxisSet{1})
        {
            ef->get_instructions()->push_back(
                make_shared<runtime::eigen::BroadcastVectorColwiseInstruction<element::Float32>>(
                    in[0], out[0]));
        }
        else if (broadcast->get_broadcast_axes() == AxisSet{0})
        {
            ef->get_instructions()->push_back(
                make_shared<runtime::eigen::BroadcastVectorRowwiseInstruction<element::Float32>>(
                    in[0], out[0]));
        }
        else
        {
            // Unreachable if propagate_types validated the axes.
            throw ngraph_error("Internal error: axis set for vector-matrix broadcast is neither {0} nor {1}");
        }
    }
    else
    {
        throw ngraph_error("Broadcast not implemented for rank>2 in VM yet");
    }
};
REGISTER_TO_OP_MAP(op::Concat)
{
auto result_tensor_type =
......
......@@ -669,3 +669,122 @@ TEST(execute, test_function_call)
(*cf)({x, z, y}, {result});
ASSERT_EQ((vector<float>{100, 144, 196, 256}), result->get_vector());
}
TEST(execute, test_broadcast_scalar_vector)
{
    // Broadcast a rank-0 (scalar) tensor into a length-4 vector.
    auto arg_shape = Shape{};
    auto result_shape = Shape{4};
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), arg_shape);
    auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), result_shape);
    auto fn = make_shared<Function>(
        make_shared<op::Broadcast>(param, result_shape, AxisSet{0}), result_type, op::Parameters{param});

    auto ext_fn = make_shared<ngraph::runtime::ExternalFunction>(fn);
    auto frame = ext_fn->make_call_frame();

    // Set up the input scalar and the output vector.
    auto arg_tv = ngraph::runtime::make_tensor<element::Float32>(arg_shape);
    *arg_tv = vector<float>{6};
    auto result_tv = ngraph::runtime::make_tensor<element::Float32>(result_shape);

    (*frame)({arg_tv}, {result_tv});
    ASSERT_EQ((vector<float>{6, 6, 6, 6}), result_tv->get_vector());
}
TEST(execute, test_broadcast_scalar_matrix)
{
    // Broadcast a rank-0 (scalar) tensor into a 2x2 matrix.
    auto arg_shape = Shape{};
    auto result_shape = Shape{2, 2};
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), arg_shape);
    auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), result_shape);
    auto fn = make_shared<Function>(
        make_shared<op::Broadcast>(param, result_shape, AxisSet{0, 1}), result_type, op::Parameters{param});

    auto ext_fn = make_shared<ngraph::runtime::ExternalFunction>(fn);
    auto frame = ext_fn->make_call_frame();

    // Set up the input scalar and the output matrix.
    auto arg_tv = ngraph::runtime::make_tensor<element::Float32>(arg_shape);
    *arg_tv = vector<float>{6};
    auto result_tv = ngraph::runtime::make_tensor<element::Float32>(result_shape);

    (*frame)({arg_tv}, {result_tv});
    ASSERT_EQ((vector<float>{6, 6, 6, 6}), result_tv->get_vector());
}
TEST(execute, test_broadcast_scalar_tensor)
{
    // Broadcast a rank-0 (scalar) tensor into a 2x2x2 rank-3 tensor.
    auto arg_shape = Shape{};
    auto result_shape = Shape{2, 2, 2};
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), arg_shape);
    auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), result_shape);
    auto fn = make_shared<Function>(
        make_shared<op::Broadcast>(param, result_shape, AxisSet{0, 1, 2}), result_type, op::Parameters{param});

    auto ext_fn = make_shared<ngraph::runtime::ExternalFunction>(fn);
    auto frame = ext_fn->make_call_frame();

    // Set up the input scalar and the rank-3 output.
    auto arg_tv = ngraph::runtime::make_tensor<element::Float32>(arg_shape);
    *arg_tv = vector<float>{6};
    auto result_tv = ngraph::runtime::make_tensor<element::Float32>(result_shape);

    (*frame)({arg_tv}, {result_tv});
    ASSERT_EQ((vector<float>{6, 6, 6, 6, 6, 6, 6, 6}), result_tv->get_vector());
}
TEST(execute, test_broadcast_trivial)
{
    // Broadcast with an empty axis set: the op must behave as a plain copy.
    auto shape = Shape{2, 2, 2};
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), shape);
    auto fn = make_shared<Function>(
        make_shared<op::Broadcast>(param, shape, AxisSet{}), result_type, op::Parameters{param});

    auto ext_fn = make_shared<ngraph::runtime::ExternalFunction>(fn);
    auto frame = ext_fn->make_call_frame();

    // Input and output share the same shape.
    auto arg_tv = ngraph::runtime::make_tensor<element::Float32>(shape);
    *arg_tv = vector<float>{2, 4, 6, 8, 16, 32, 64, 128};
    auto result_tv = ngraph::runtime::make_tensor<element::Float32>(shape);

    (*frame)({arg_tv}, {result_tv});
    ASSERT_EQ((vector<float>{2, 4, 6, 8, 16, 32, 64, 128}), result_tv->get_vector());
}
TEST(execute, test_broadcast_vector_colwise)
{
    // Broadcast a length-3 vector across the columns of a 3x4 matrix
    // (axis 1 is the broadcast axis).
    auto arg_shape = Shape{3};
    auto result_shape = Shape{3, 4};
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), arg_shape);
    auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), result_shape);
    auto fn = make_shared<Function>(
        make_shared<op::Broadcast>(param, result_shape, AxisSet{1}), result_type, op::Parameters{param});

    auto ext_fn = make_shared<ngraph::runtime::ExternalFunction>(fn);
    auto frame = ext_fn->make_call_frame();

    // Set up the input vector and the output matrix.
    auto arg_tv = ngraph::runtime::make_tensor<element::Float32>(arg_shape);
    *arg_tv = vector<float>{1, 2, 3};
    auto result_tv = ngraph::runtime::make_tensor<element::Float32>(result_shape);

    (*frame)({arg_tv}, {result_tv});
    ASSERT_EQ((vector<float>{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}), result_tv->get_vector());
}
TEST(execute, test_broadcast_vector_rowwise)
{
    // Broadcast a length-4 vector across the rows of a 3x4 matrix
    // (axis 0 is the broadcast axis).
    auto arg_shape = Shape{4};
    auto result_shape = Shape{3, 4};
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), arg_shape);
    auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), result_shape);
    auto fn = make_shared<Function>(
        make_shared<op::Broadcast>(param, result_shape, AxisSet{0}), result_type, op::Parameters{param});

    auto ext_fn = make_shared<ngraph::runtime::ExternalFunction>(fn);
    auto frame = ext_fn->make_call_frame();

    // Set up the input vector and the output matrix.
    auto arg_tv = ngraph::runtime::make_tensor<element::Float32>(arg_shape);
    *arg_tv = vector<float>{1, 2, 3, 4};
    auto result_tv = ngraph::runtime::make_tensor<element::Float32>(result_shape);

    (*frame)({arg_tv}, {result_tv});
    ASSERT_EQ((vector<float>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}), result_tv->get_vector());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment