Commit 953b8c59 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #115 from NervanaSystems/cyphers/eigenfunction

VM for interpreting ops
parents c19a5319 abff9517
......@@ -22,9 +22,10 @@
namespace ngraph
{
class Node;
namespace op {
namespace op
{
class Parameter;
/// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>;
}
......@@ -47,5 +48,4 @@ namespace ngraph
/// Strides of a tensor
using Strides = std::vector<size_t>;
}
......@@ -55,9 +55,13 @@
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/return.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/function.hpp"
#include "ngraph/function.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
......@@ -12,18 +12,41 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace runtime;
using namespace ngraph::runtime;
CallFrame::CallFrame(size_t n_inputs,
size_t n_outputs,
const PTVs& temps,
size_t initial_pc,
const shared_ptr<vector<shared_ptr<Instruction>>>& instructions)
: m_n_inputs(n_inputs)
, m_n_outputs(n_outputs)
, m_tensors(n_inputs + n_outputs + temps.size())
, m_initial_pc(initial_pc)
, m_instructions(instructions)
{
copy(temps.begin(), temps.end(), m_tensors.begin() + m_n_inputs + m_n_outputs);
}
CallFrame::CallFrame(Function& function,
const std::vector<std::shared_ptr<PrimaryTensorView>>& arguments,
const std::vector<std::shared_ptr<PrimaryTensorView>>& results)
void CallFrame::operator()(const PTVs& inputs, const PTVs& outputs)
{
m_tensors.insert(m_tensors.end(), arguments.begin(), arguments.end());
m_tensors.insert(m_tensors.end(), results.begin(), results.end());
// TBD
// From Function allocate tensors for the temporaries
copy(inputs.begin(), inputs.end(), m_tensors.begin());
copy(outputs.begin(), outputs.end(), m_tensors.begin() + m_n_inputs);
m_next_pc = m_initial_pc;
m_return = false;
while (!m_return)
{
m_pc = m_next_pc;
m_next_pc = m_pc + 1;
m_instructions->at(m_pc)->execute(*this);
}
// Don't hold onto inputs/outputs
fill_n(m_tensors.begin(), m_n_inputs + m_n_outputs, nullptr);
}
......@@ -17,43 +17,41 @@
#include <memory>
#include <vector>
#include "ngraph/runtime/function.hpp"
#include "ngraph/function.hpp"
#include "ngraph/runtime/instruction.hpp"
namespace ngraph
{
namespace runtime
{
class CallFrameAccessor;
using PTVs = std::vector<std::shared_ptr<ngraph::runtime::PrimaryTensorView>>;
// This is constructed when a runtime function is called.
class CallFrame
{
friend class CallFrameAccessor;
public:
CallFrame(Function& function,
const std::vector<std::shared_ptr<PrimaryTensorView>>& arguments,
const std::vector<std::shared_ptr<PrimaryTensorView>>& results);
class PrimaryTensorView;
protected:
std::vector<std::shared_ptr<PrimaryTensorView>> m_tensors;
};
class CallFrameAccessor
// A VM for executing lightly-compiled graph functions.
class CallFrame
{
public:
CallFrameAccessor(size_t index)
: m_index(index)
{
}
CallFrame(
size_t n_inputs,
size_t n_outputs,
const PTVs& temps,
size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
std::shared_ptr<PrimaryTensorView> operator()(CallFrame& call_frame)
{
return call_frame.m_tensors[m_index];
}
void operator()(const PTVs& inputs, const PTVs& outpus);
void set_return() { m_return = true; }
std::shared_ptr<PrimaryTensorView> get_tensor(size_t i) { return m_tensors[i]; }
protected:
size_t m_index;
size_t m_n_inputs;
size_t m_n_outputs;
PTVs m_tensors;
size_t m_initial_pc;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
size_t m_pc;
size_t m_next_pc;
bool m_return;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class AddInstruction : public Instruction
{
public:
AddInstruction(size_t arg0, size_t arg1, size_t out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_out))
->get_map() =
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_arg0))
->get_map() +
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_arg1))
->get_map();
}
protected:
size_t m_arg0;
size_t m_arg1;
size_t m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class MultiplyInstruction : public Instruction
{
public:
MultiplyInstruction(size_t arg0, size_t arg1, size_t out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_out))
->get_map() =
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_arg0))
->get_map() *
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_arg1))
->get_map();
}
protected:
size_t m_arg0;
size_t m_arg1;
size_t m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
class ReturnInstruction : public Instruction
{
public:
ReturnInstruction() {}
virtual void execute(CallFrame& call_frame) const override
{
call_frame.set_return();
}
};
}
}
}
......@@ -14,24 +14,18 @@
#pragma once
#include <memory>
#include <vector>
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
// A compiled graph function
class Function
class CallFrame;
// An interpreter for an Op
class Instruction
{
public:
virtual ~Function() {}
// Invoke the function with a the given inputs and outputs
void operator()(std::vector<std::shared_ptr<PrimaryTensorView>> inputs,
std::vector<std::shared_ptr<PrimaryTensorView>> outputs);
virtual ~Instruction(){}
virtual void execute(CallFrame& call_frame) const = 0;
};
}
}
......@@ -21,6 +21,7 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::runtime;
using namespace ngraph::runtime::eigen;
TEST(runtime, test_add)
......@@ -44,3 +45,45 @@ TEST(runtime, test_multiply)
multiply(*x->get_value(), *y->get_value(), *z->get_value());
ASSERT_EQ((vector<float>{5, 12, 21, 32}), z->get_value()->get_vector());
}
TEST(runtime, test_add_multiply)
{
// Inputs:
// 0 : a
// 1 : b
// 2 : c
// Outputs:
// 3 : result
// Temporaries
// 4: t0
auto instructions = make_shared<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>();
// a + b -> t0
instructions->push_back(make_shared<AddInstruction<element::Float32>>(0, 1, 4));
// t0 * c -> result
instructions->push_back(make_shared<MultiplyInstruction<element::Float32>>(4, 2, 3));
instructions->push_back(make_shared<ReturnInstruction>());
runtime::CallFrame cf{3,
1,
PTVs{make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2})},
0,
instructions};
// Create some tensors for input/output
auto a = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*a = vector<float>{1, 2, 3, 4};
auto b = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*b = vector<float>{5, 6, 7, 8};
auto c = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*c = vector<float>{9, 10, 11, 12};
auto result = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
cf(PTVs{a, b, c}, PTVs{result});
ASSERT_EQ((vector<float>{54, 80, 110, 144}), result->get_vector());
cf(PTVs{b, a, c}, PTVs{result});
ASSERT_EQ((vector<float>{54, 80, 110, 144}), result->get_vector());
cf(PTVs{a, c, b}, PTVs{result});
ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment