Commit eb2a4980 authored by Scott Cyphers's avatar Scott Cyphers

More plumbing

parent 2504daa1
......@@ -27,6 +27,8 @@ set (SRC
transformers/mock_transformer.cpp
transformers/ndarray.cpp
transformers/op_graph.cpp
values/function.cpp
)
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
......
#include "values/function.hpp"
using namespace std;
using namespace ngraph;
Parameter::ptr_t Parameter::make(Function& function, size_t index, const ValueType::ptr_t& output_type){
return ptr_t(new Parameter(function, index, output_type));
}
Function::ptr_t Function::make(const ValueType::ptr_t& return_type, const std::vector<ValueType::ptr_t>& argument_types){
return ptr_t(new Function(return_type, argument_types));
}
#pragma once
#include "values/descriptors.hpp"
#include "values/op.hpp"
#include "values/types.hpp"
namespace ngraph {
class Function;
class Parameter : public Op
{
public:
using ptr_t = std::shared_ptr<Parameter>;
static ptr_t make(Function& function, size_t index, const ValueType::ptr_t& output_type);
protected:
Parameter(Function& function, size_t index, const ValueType::ptr_t& output_type)
: Op({}, output_type)
, m_function(function)
, m_index(index)
{}
Function& m_function;
size_t m_index;
};
class Function
{
public:
using ptr_t = std::shared_ptr<Function>;
Function(const ValueType::ptr_t& return_type, const std::vector<ValueType::ptr_t>& argument_types)
protected:
Function(const ValueType::ptr_t& return_type,
const std::vector<ValueType::ptr_t>& argument_types)
: m_return_type(return_type)
, m_argument_types(argument_types)
{}
{
size_t i = 0;
for (auto argument_type : argument_types){
m_parameters.push_back(Parameter::make(*this, i++, argument_type));
}
}
public:
static ptr_t make(const ValueType::ptr_t& return_type,
const std::vector<ValueType::ptr_t>& argument_types);
static ptr_t make(const ValueType::ptr_t& return_type, const std::vector<ValueType::ptr_t>& argument_types){
return ptr_t(new Function(return_type, argument_types));
Parameter::ptr_t parameter(size_t i){
return m_parameters[i];
}
protected:
std::vector<Parameter::ptr_t> m_parameters;
std::vector<std::shared_ptr<ValueType>> m_argument_types;
std::shared_ptr<ValueType> m_return_type;
};
......
#pragma once
#include <memory>
#include "values/descriptors.hpp"
#include "values/types.hpp"
namespace ngraph {
class Op
{
public:
using ptr_t = std::shared_ptr<Op>;
protected:
Op(const std::vector<ptr_t>& inputs, const ValueType::ptr_t output_type)
: m_inputs(inputs)
, m_output_type(output_type)
{}
std::vector<ptr_t> m_inputs;
ValueType::ptr_t m_output_type;
};
class Broadcast : public Op
{
public:
using ptr_t = std::shared_ptr<Broadcast>;
protected:
Broadcast(const Op::ptr_t& x, std::vector<size_t> dims)
: Op({x}, 0)
, m_dims(dims)
{}
public:
static ptr_t make(const Op::ptr_t& x, std::vector<size_t> dims){
return ptr_t(new Broadcast(x, dims));
}
protected:
std::vector<size_t> m_dims;
};
} // end of namespace ngraph
\ No newline at end of file
......@@ -7,6 +7,9 @@
namespace ngraph {
class TensorViewDescriptor;
class TupleDescriptor;
using value_size_t = size_t;
// Base type for ngraph values
......@@ -20,6 +23,7 @@ class TensorViewType : public ValueType
{
public:
using ptr_t = std::shared_ptr<TensorViewType>;
using descriptor_t = TensorViewDescriptor;
TensorViewType(const ElementType& element_type, const std::vector<value_size_t>& shape)
: m_element_type(element_type)
......@@ -39,6 +43,7 @@ class TupleType : public ValueType
{
public:
using ptr_t = std::shared_ptr<TupleType>;
using descriptor_t = TupleDescriptor;
TupleType(const std::vector<ValueType::ptr_t>& element_types)
: m_element_types(element_types)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment