Commit 8895e895 authored by Jaikrishnan Menon

CPU: First-cut backend reusing most of the NGVM machinery, which will be incrementally replaced
parent 6da5d38d
......@@ -149,6 +149,10 @@ std::unique_ptr<llvm::Module> execution_state::compile(const string& source, con
LO->WChar = 1;
LO->RTTI = 1;
// CodeGen options
auto& CGO = Clang->getInvocation().getCodeGenOpts();
CGO.setDebugInfo(codegenoptions::FullDebugInfo);
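// (FullDebugInfo makes the generated source debuggable at source level.)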
// Map code filename to a MemoryBuffer
StringRef source_ref(source);
unique_ptr<MemoryBuffer> buffer = MemoryBuffer::getMemBufferCopy(source_ref);
......
......@@ -20,13 +20,15 @@
using namespace std;
using namespace ngraph::runtime::cpu;
CallFrame::CallFrame(size_t n_inputs,
CallFrame::CallFrame(EntryPoint compiled_function,
size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps)
: m_n_inputs(n_inputs)
, m_n_outputs(n_outputs)
, m_tensor_views(n_inputs + n_outputs + temps.size())
, m_compiled_function(compiled_function)
{
copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs);
}
......@@ -39,6 +41,7 @@ void CallFrame::tensor_call(
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
// TODO: Execute!
m_compiled_function(this, m_tensor_views);
// Don't hold onto inputs/outputs
fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr);
......
......@@ -16,6 +16,7 @@
#include <memory>
#include <vector>
#include <functional>
#include "ngraph/function.hpp"
#include "ngraph/runtime/call_frame.hpp"
......@@ -30,12 +31,15 @@ namespace ngraph
namespace cpu
{
class Instruction;
class CallFrame;
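// Signature of the function produced by the code generator: it receives the
// executing call frame and the frame's tensor views (inputs, outputs and temporaries).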
using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*, ngraph::runtime::TensorViewPtrs&)>;
// Compile and execute graphs
class CallFrame : public ngraph::runtime::CallFrame
{
public:
CallFrame(
EntryPoint compiled_function,
size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps);
......@@ -69,6 +73,7 @@ namespace ngraph
size_t m_n_outputs;
TensorViewPtrs m_tensor_views;
bool m_return;
EntryPoint m_compiled_function;
};
}
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <Eigen/Dense>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
class TensorViewInfo;
namespace cpu
{
class CallFrame;
namespace eigen
{
using DynamicStrides = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
using VectorStrides = Eigen::Stride<Eigen::Dynamic, 1>;
template <typename ET>
using DynamicArray =
Eigen::Array<typename ET::type, Eigen::Dynamic, Eigen::Dynamic>;
template <typename ET>
using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicMatrix = Eigen::
Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
template <typename ET>
using EigenMatrixBase = Eigen::Map<DynamicMatrix<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicVector = Eigen::Matrix<typename ET::type, Eigen::Dynamic, 1>;
template <typename ET>
using EigenVectorBase = Eigen::Map<DynamicVector<ET>, 0, VectorStrides>;
namespace fmt
{
/// @brief vector format for Eigen wrappers.
class V
{
public:
V(const TensorViewInfo& tensor_view_info)
: l0(tensor_view_info
.get_layout<
ngraph::descriptor::layout::DenseTensorViewLayout>()
->get_size())
{
}
V(size_t s)
: l0(s)
{
}
public:
size_t l0;
size_t l1{1};
size_t s0{1};
size_t s1{1};
};
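/// @brief matrix format (shape and strides) for Eigen wrappers.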
class M
{
M(const Shape& shape, const Strides& strides)
: l0(shape.at(0))
, l1(shape.at(1))
, s0(strides.at(0))
, s1(strides.at(1))
{
}
M(const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: M(layout->get_shape(), layout->get_strides())
{
}
public:
M(const TensorViewInfo& tensor_view_info)
: M(tensor_view_info.get_layout<
ngraph::descriptor::layout::DenseTensorViewLayout>())
{
}
public:
size_t l0;
size_t l1;
size_t s0;
size_t s1;
};
}
// ET   - element type
// FMT  - array format (fmt::V for vector, etc.)
// BASE - base map type; selects array vs. matrix
template <typename ET,
typename FMT,
typename BASE,
typename STRIDES = DynamicStrides>
class EigenWrapper : public BASE
{
using base = BASE;
public:
EigenWrapper(typename ET::type* t, const FMT& fmt)
: base(t, fmt.l0, fmt.l1, STRIDES(fmt.s0, fmt.s1))
{
}
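// Fallback mapping: treat the tensor as one contiguous column (size x 1) with unit strides, ignoring FMT.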
EigenWrapper(
typename ET::type* t,
const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: base(t, layout->get_size(), 1, DynamicStrides(1, 1))
{
}
EigenWrapper(CallFrame* call_frame, const TensorViewInfo& tensor_view_info)
: EigenWrapper(
call_frame->get_tensor_view_data<ET>(tensor_view_info.get_index()),
FMT(tensor_view_info))
{
}
template <typename U>
EigenWrapper& operator=(const U& other)
{
this->base::operator=(other);
return *this;
}
};
template <typename ET, typename FMT = fmt::V>
using EigenArray1d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenArray2d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenMatrix = EigenWrapper<ET, FMT, EigenMatrixBase<ET>>;
template <typename ET, typename FMT = fmt::V>
using EigenVector = EigenWrapper<ET, FMT, EigenVectorBase<ET>, VectorStrides>;
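// Usage sketch (hypothetical buffer names, contiguous 4-element Float32 tensors):
//   EigenArray1d<ngraph::element::Float32>(out, fmt::V(4)) =
//       EigenArray1d<ngraph::element::Float32>(arg0, fmt::V(4)) +
//       EigenArray1d<ngraph::element::Float32>(arg1, fmt::V(4));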
}
}
}
}
......@@ -14,8 +14,12 @@
#include <iostream>
#include <vector>
#include <typeindex>
#include <string>
#include <unordered_map>
#include "ngraph/node.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/external_function.hpp"
#include "ngraph/runtime/cpu/emitter.hpp"
......@@ -23,11 +27,29 @@
using namespace std;
using namespace ngraph::runtime::cpu;
using ngraph::descriptor::layout::DenseTensorViewLayout;
#define TI(x) type_index(typeid(x))
static unordered_map<type_index, string> element_type_names = {{TI(ngraph::element::Bool), "Bool"},
{TI(ngraph::element::Float32), "Float32"},
{TI(ngraph::element::Int8), "Int8"},
{TI(ngraph::element::Int32), "Int32"},
{TI(ngraph::element::Int64), "Int64"},
{TI(ngraph::element::UInt8), "UInt8"},
{TI(ngraph::element::UInt32), "UInt32"},
{TI(ngraph::element::UInt64), "UInt64"}
};
#define EIGEN_VECTOR_FORMAT(x) "{" + to_string(x) + "}"
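// e.g. EIGEN_VECTOR_FORMAT(4) splices the literal "{4}" into the generated
// source, where it brace-initializes a fmt::V.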
#define EIGEN_MATRIX_FORMAT(x)
void Emitter::EmitNop(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs) const
const std::vector<TensorViewInfo>& outputs)
{
}
......@@ -36,72 +58,52 @@ void Emitter::EmitAdd(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs) const
const std::vector<TensorViewInfo>& outputs)
{
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>(
n->get_arguments().at(0)->get_value_type()))
->get_element_type();
TU += " {\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(et)] + ">(" + to_string(inputs[0].get_index()) + ");\n"
" auto arg1 = call_frame->get_tensor_view_data<" + element_type_names[TI(et)] + ">(" + to_string(inputs[1].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(et)] + ">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenArray1d<" + element_type_names[TI(et)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenArray1d<" + element_type_names[TI(et)] + ">(arg0, "
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") +\n"
" EigenArray1d<" + element_type_names[TI(et)] + ">(arg1, "
EIGEN_VECTOR_FORMAT(inputs[1].get_layout<DenseTensorViewLayout>()->get_size()) ");\n"
" }\n";
}
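// Reconstructed sketch of the block EmitAdd generates for contiguous
// 4-element Float32 tensors (tensor view indices 0, 1, 2 assumed):
//   {
//       auto arg0 = call_frame->get_tensor_view_data<Float32>(0);
//       auto arg1 = call_frame->get_tensor_view_data<Float32>(1);
//       auto out = call_frame->get_tensor_view_data<Float32>(2);
//       EigenArray1d<Float32>(out, {4}) =
//           EigenArray1d<Float32>(arg0, {4}) +
//           EigenArray1d<Float32>(arg1, {4});
//   }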
void Emitter::EmitDot(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs) const
const std::vector<TensorViewInfo>& outputs)
{
auto& arg_nodes = n->get_arguments();
assert(arg_nodes.size() == 2);
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(arg_nodes.at(0)->get_value_type());
assert(nullptr != arg0_tensor_type);
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(arg_nodes.at(1)->get_value_type());
assert(nullptr != arg1_tensor_type);
auto arg0_shape = arg0_tensor_type->get_shape();
auto arg1_shape = arg1_tensor_type->get_shape();
auto& arg0_element_type = arg0_tensor_type->get_element_type();
// If arg0 or arg1 is a scalar, emit a scalar-tensor product.
if (arg0_shape.size() == 0)
{
cout << "Emitting scalar-tensor product\n";
}
else if (arg1_shape.size() == 0)
{
cout << "Emitting scalar-tensor product\n";
}
// If arg0 and arg1 are both vectors, emit a dot product.
else if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
{
cout << "Emitting dot product\n";
}
// If arg0 is a matrix and arg1 is a vector, emit a matrix-vector product.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
{
cout << "Emitting matrix-vector product\n";
}
// If arg0 and arg1 are both matrices, emit a matrix product.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 2)
{
cout << "Emitting matrix multiply\n";
}
else
{
throw ngraph_error("Dot product for tensors with rank>2 not implemented yet.");
}
}
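// In this first cut the branches above only log the product kind; TU emission for Dot is not wired up yet.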
void Emitter::EmitMultiply(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs) const
const std::vector<TensorViewInfo>& outputs)
{
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>(
n->get_arguments().at(0)->get_value_type()))
->get_element_type();
TU += " {\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(et)] + ">(" + to_string(inputs[0].get_index()) + ");\n"
" auto arg1 = call_frame->get_tensor_view_data<" + element_type_names[TI(et)] + ">(" + to_string(inputs[1].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(et)] + ">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenArray1d<" + element_type_names[TI(et)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenArray1d<" + element_type_names[TI(et)] + ">(arg0, "
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") *\n"
" EigenArray1d<" + element_type_names[TI(et)] + ">(arg1, "
EIGEN_VECTOR_FORMAT(inputs[1].get_layout<DenseTensorViewLayout>()->get_size()) ");\n"
" }\n";
}
......@@ -40,25 +40,25 @@ namespace ngraph
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs) const;
const std::vector<TensorViewInfo>& outputs);
void EmitAdd(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs) const;
const std::vector<TensorViewInfo>& outputs);
void EmitDot(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs) const;
const std::vector<TensorViewInfo>& outputs);
void EmitMultiply(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs) const;
const std::vector<TensorViewInfo>& outputs);
};
}
......
......@@ -18,6 +18,7 @@
#include <typeinfo>
#include <unordered_map>
#include <string>
#include <fstream>
#include "ngraph/codegen/compiler.hpp"
#include "ngraph/descriptor/input.hpp"
......@@ -76,6 +77,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
bool release_function)
: ngraph::runtime::ExternalFunction(function, release_function)
, m_instructions(make_shared<std::vector<std::shared_ptr<Instruction>>>())
, compiled_function(nullptr)
{
}
......@@ -150,7 +152,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Now we build the TU
Emitter emitter;
auto TU = emitter.GetTU();
auto& TU = emitter.GetTU();
TU += R"(
#include <vector>
#include <memory>
......@@ -158,11 +160,16 @@ void ExternalFunction::compile(FunctionMap& function_map)
#include <Eigen/Dense>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/eigen_utils.hpp"
void *__dso_handle = 0;
using namespace ngraph::element;
using namespace ngraph::runtime;
using namespace ngraph::runtime::cpu::eigen;
extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
ngraph::runtime::TensorViewPtrs& tensor_views)
{
......@@ -196,12 +203,19 @@ extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
// End TU
TU += "}\n";
// TODO: Cleanup and make this a utility function
ofstream out("__ngcpu_codegen.cpp");
out << TU;
out.close();
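// Passing the on-disk name to compile() below presumably lets the emitted debug info reference the dumped source file.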
ngraph::codegen::execution_state estate;
auto llvm_module = estate.compile(TU, "ExternalFunction");
auto llvm_module = estate.compile(TU, "__ngcpu_codegen.cpp");
assert(llvm_module);
estate.add_module(llvm_module);
estate.finalize();
//auto llvm_func = estate.find_function
compiled_function = estate.find_function<void(ngraph::runtime::cpu::CallFrame*,
ngraph::runtime::TensorViewPtrs&)>("__entrypoint");
assert(compiled_function);
m_is_compiled = true;
if (m_release_function)
......@@ -278,5 +292,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame()
#undef M
}
return make_shared<ngraph::runtime::cpu::CallFrame>(
m_n_inputs, m_n_outputs, temps);
compiled_function, m_n_inputs, m_n_outputs, temps);
}
......@@ -18,6 +18,7 @@
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <functional>
#include "ngraph/function.hpp"
#include "ngraph/codegen/compiler.hpp"
......@@ -38,7 +39,7 @@ namespace ngraph
using FunctionMap = std::unordered_map<std::shared_ptr<Function>,
std::shared_ptr<ExternalFunction>>;
using OpFunction = std::function<void(const Emitter*,
using OpFunction = std::function<void(Emitter*,
const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
......@@ -62,6 +63,7 @@ namespace ngraph
size_t m_n_outputs;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
ngraph::descriptor::TensorViewPtrs m_temp_views;
EntryPoint compiled_function;
};
}
}
......
......@@ -27,7 +27,7 @@ namespace ngraph
{
public:
TensorViewInfo(size_t index,
const std::shared_ptr<ngraph::descriptor::TensorView>& descriptor)
const std::shared_ptr<const ngraph::descriptor::TensorView>& descriptor)
: m_index(index)
, m_layout(descriptor->get_tensor_view_layout())
{
......
......@@ -21,6 +21,33 @@
using namespace std;
using namespace ngraph;
TEST(cpu, ab)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>(A + B, rt, op::Parameters{A, B});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_parameterized_tensor_view<element::Float32>(shape);
*a = vector<float>{1, 2, 3, 4};
auto b = backend->make_parameterized_tensor_view<element::Float32>(shape);
*b = vector<float>{5, 6, 7, 8};
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
(*cf)({a, b}, {result});
ASSERT_EQ((vector<float>{6, 8, 10, 12}), result->get_vector());
(*cf)({b, a}, {result});
ASSERT_EQ((vector<float>{6, 8, 10, 12}), result->get_vector());
}
TEST(cpu, abc)
{
auto shape = Shape{2, 2};
......@@ -55,7 +82,7 @@ TEST(cpu, abc)
}
/*
TEST(execute, abc_int64)
TEST(cpu, abc_int64)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape);
......