Commit abff9517 authored by Scott Cyphers's avatar Scott Cyphers

VM for interpreting ops

parent 36cc0317
...@@ -22,9 +22,10 @@ ...@@ -22,9 +22,10 @@
namespace ngraph namespace ngraph
{ {
class Node; class Node;
namespace op { namespace op
{
class Parameter; class Parameter;
/// A list of parameters /// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>; using Parameters = std::vector<std::shared_ptr<Parameter>>;
} }
...@@ -47,5 +48,4 @@ namespace ngraph ...@@ -47,5 +48,4 @@ namespace ngraph
/// Strides of a tensor /// Strides of a tensor
using Strides = std::vector<size_t>; using Strides = std::vector<size_t>;
} }
...@@ -44,9 +44,13 @@ ...@@ -44,9 +44,13 @@
#include "ngraph/ops/parameter.hpp" #include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/return.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp" #include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
...@@ -12,18 +12,41 @@ ...@@ -12,18 +12,41 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <algorithm>
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace runtime; using namespace ngraph::runtime;
CallFrame::CallFrame(size_t n_inputs,
size_t n_outputs,
const PTVs& temps,
size_t initial_pc,
const shared_ptr<vector<shared_ptr<Instruction>>>& instructions)
: m_n_inputs(n_inputs)
, m_n_outputs(n_outputs)
, m_tensors(n_inputs + n_outputs + temps.size())
, m_initial_pc(initial_pc)
, m_instructions(instructions)
{
copy(temps.begin(), temps.end(), m_tensors.begin() + m_n_inputs + m_n_outputs);
}
CallFrame::CallFrame(Function& function, void CallFrame::operator()(const PTVs& inputs, const PTVs& outputs)
const std::vector<std::shared_ptr<PrimaryTensorView>>& arguments,
const std::vector<std::shared_ptr<PrimaryTensorView>>& results)
{ {
m_tensors.insert(m_tensors.end(), arguments.begin(), arguments.end()); copy(inputs.begin(), inputs.end(), m_tensors.begin());
m_tensors.insert(m_tensors.end(), results.begin(), results.end()); copy(outputs.begin(), outputs.end(), m_tensors.begin() + m_n_inputs);
// TBD m_next_pc = m_initial_pc;
// From Function allocate tensors for the temporaries m_return = false;
while (!m_return)
{
m_pc = m_next_pc;
m_next_pc = m_pc + 1;
m_instructions->at(m_pc)->execute(*this);
}
// Don't hold onto inputs/outputs
fill_n(m_tensors.begin(), m_n_inputs + m_n_outputs, nullptr);
} }
...@@ -17,43 +17,41 @@ ...@@ -17,43 +17,41 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/runtime/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/runtime/instruction.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
class CallFrameAccessor; using PTVs = std::vector<std::shared_ptr<ngraph::runtime::PrimaryTensorView>>;
// This is constructed when a runtime function is called. class PrimaryTensorView;
class CallFrame
{
friend class CallFrameAccessor;
public:
CallFrame(Function& function,
const std::vector<std::shared_ptr<PrimaryTensorView>>& arguments,
const std::vector<std::shared_ptr<PrimaryTensorView>>& results);
protected: // A VM for executing lightly-compiled graph functions.
std::vector<std::shared_ptr<PrimaryTensorView>> m_tensors; class CallFrame
};
class CallFrameAccessor
{ {
public: public:
CallFrameAccessor(size_t index) CallFrame(
: m_index(index) size_t n_inputs,
{ size_t n_outputs,
} const PTVs& temps,
size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
std::shared_ptr<PrimaryTensorView> operator()(CallFrame& call_frame) void operator()(const PTVs& inputs, const PTVs& outpus);
{ void set_return() { m_return = true; }
return call_frame.m_tensors[m_index]; std::shared_ptr<PrimaryTensorView> get_tensor(size_t i) { return m_tensors[i]; }
}
protected: protected:
size_t m_index; size_t m_n_inputs;
size_t m_n_outputs;
PTVs m_tensors;
size_t m_initial_pc;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
size_t m_pc;
size_t m_next_pc;
bool m_return;
}; };
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class AddInstruction : public Instruction
{
public:
AddInstruction(size_t arg0, size_t arg1, size_t out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_out))
->get_map() =
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_arg0))
->get_map() +
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_arg1))
->get_map();
}
protected:
size_t m_arg0;
size_t m_arg1;
size_t m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class MultiplyInstruction : public Instruction
{
public:
MultiplyInstruction(size_t arg0, size_t arg1, size_t out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_out))
->get_map() =
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_arg0))
->get_map() *
dynamic_cast<PrimaryTensorView<ET>*>(&*call_frame.get_tensor(m_arg1))
->get_map();
}
protected:
size_t m_arg0;
size_t m_arg1;
size_t m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
class ReturnInstruction : public Instruction
{
public:
ReturnInstruction() {}
virtual void execute(CallFrame& call_frame) const override
{
call_frame.set_return();
}
};
}
}
}
...@@ -14,24 +14,18 @@ ...@@ -14,24 +14,18 @@
#pragma once #pragma once
#include <memory>
#include <vector>
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
// A compiled graph function class CallFrame;
class Function
// An interpreter for an Op
class Instruction
{ {
public: public:
virtual ~Function() {} virtual ~Instruction(){}
virtual void execute(CallFrame& call_frame) const = 0;
// Invoke the function with a the given inputs and outputs
void operator()(std::vector<std::shared_ptr<PrimaryTensorView>> inputs,
std::vector<std::shared_ptr<PrimaryTensorView>> outputs);
}; };
} }
} }
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace ngraph::runtime;
using namespace ngraph::runtime::eigen; using namespace ngraph::runtime::eigen;
TEST(runtime, test_add) TEST(runtime, test_add)
...@@ -44,3 +45,45 @@ TEST(runtime, test_multiply) ...@@ -44,3 +45,45 @@ TEST(runtime, test_multiply)
multiply(*x->get_value(), *y->get_value(), *z->get_value()); multiply(*x->get_value(), *y->get_value(), *z->get_value());
ASSERT_EQ((vector<float>{5, 12, 21, 32}), z->get_value()->get_vector()); ASSERT_EQ((vector<float>{5, 12, 21, 32}), z->get_value()->get_vector());
} }
TEST(runtime, test_add_multiply)
{
// Inputs:
// 0 : a
// 1 : b
// 2 : c
// Outputs:
// 3 : result
// Temporaries
// 4: t0
auto instructions = make_shared<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>();
// a + b -> t0
instructions->push_back(make_shared<AddInstruction<element::Float32>>(0, 1, 4));
// t0 * c -> result
instructions->push_back(make_shared<MultiplyInstruction<element::Float32>>(4, 2, 3));
instructions->push_back(make_shared<ReturnInstruction>());
runtime::CallFrame cf{3,
1,
PTVs{make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2})},
0,
instructions};
// Create some tensors for input/output
auto a = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*a = vector<float>{1, 2, 3, 4};
auto b = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*b = vector<float>{5, 6, 7, 8};
auto c = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
*c = vector<float>{9, 10, 11, 12};
auto result = make_shared<PrimaryTensorView<element::Float32>>(Shape{2, 2});
cf(PTVs{a, b, c}, PTVs{result});
ASSERT_EQ((vector<float>{54, 80, 110, 144}), result->get_vector());
cf(PTVs{b, a, c}, PTVs{result});
ASSERT_EQ((vector<float>{54, 80, 110, 144}), result->get_vector());
cf(PTVs{a, c, b}, PTVs{result});
ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment