Commit 2504daa1 authored by Scott Cyphers's avatar Scott Cyphers

Add Function, cleanups.

parent 0b1d09ae
#pragma once #pragma once
#include <algorithm>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -10,12 +11,16 @@ namespace ngraph { ...@@ -10,12 +11,16 @@ namespace ngraph {
class ValueDescriptor class ValueDescriptor
{ {
public: public:
virtual std::shared_ptr<ValueType> value_type() const = 0; using ptr_t = std::shared_ptr<ValueDescriptor>;
virtual ValueType::ptr_t value_type() const = 0;
}; };
class TensorDescriptor class TensorDescriptor
{ {
public: public:
using ptr_t = std::shared_ptr<TensorDescriptor>;
TensorDescriptor(const ElementType& element_type) TensorDescriptor(const ElementType& element_type)
: m_element_type(element_type) : m_element_type(element_type)
{} {}
...@@ -26,60 +31,65 @@ protected: ...@@ -26,60 +31,65 @@ protected:
class TensorLayoutDescriptor class TensorLayoutDescriptor
{ {
public:
using ptr_t = std::shared_ptr<TensorLayoutDescriptor>;
}; };
class TensorViewDescriptor : public ValueDescriptor class TensorViewDescriptor : public ValueDescriptor
{ {
public: public:
TensorViewDescriptor(const std::shared_ptr<TensorViewType>& type) using ptr_t = std::shared_ptr<TensorViewDescriptor>;
TensorViewDescriptor(const TensorViewType::ptr_t& type)
: m_type(type) : m_type(type)
{} {}
TensorViewDescriptor(const ElementType& element_type, const std::vector<value_size_t>& shape) TensorViewDescriptor(const ElementType& element_type, const std::vector<value_size_t>& shape)
: TensorViewDescriptor(TensorViewType::make_shared(element_type, shape)) : TensorViewDescriptor(TensorViewType::make(element_type, shape))
{} {}
static std::shared_ptr<TensorViewDescriptor> make_shared(const std::shared_ptr<TensorViewType>& type){ static ptr_t make(const TensorViewType::ptr_t& type){
return std::shared_ptr<TensorViewDescriptor>(new TensorViewDescriptor(type)); return ptr_t(new TensorViewDescriptor(type));
} }
static std::shared_ptr<TensorViewDescriptor> make_shared(const ElementType& element_type, const std::vector<value_size_t>& shape){ static ptr_t make(const ElementType& element_type, const std::vector<value_size_t>& shape){
return std::shared_ptr<TensorViewDescriptor>(new TensorViewDescriptor(element_type, shape)); return ptr_t(new TensorViewDescriptor(element_type, shape));
} }
std::shared_ptr<ValueType> value_type() const override { ValueType::ptr_t value_type() const override {
return m_type; return m_type;
} }
protected: protected:
std::shared_ptr<TensorViewType> m_type; TensorViewType::ptr_t m_type;
std::shared_ptr<TensorDescriptor> m_tensor_descriptor; TensorDescriptor::ptr_t m_tensor_descriptor;
std::shared_ptr<TensorLayoutDescriptor> m_tensor_layout_descriptor; TensorLayoutDescriptor::ptr_t m_tensor_layout_descriptor;
}; };
class TupleDescriptor : public ValueDescriptor class TupleDescriptor : public ValueDescriptor
{ {
public: public:
TupleDescriptor(const std::vector<std::shared_ptr<ValueDescriptor>>& elements) using ptr_t = std::shared_ptr<TupleDescriptor>;
TupleDescriptor(const std::vector<ValueDescriptor::ptr_t>& elements)
: m_element_descriptors(elements) : m_element_descriptors(elements)
{ {
std::vector<std::shared_ptr<ValueType>> types; std::vector<ValueType::ptr_t> types;
for(auto elt : elements){ for(auto elt : elements){
types.push_back(elt->value_type()); types.push_back(elt->value_type());
} }
m_type = TupleType::make_shared(types); m_type = TupleType::make(types);
} }
static std::shared_ptr<TupleDescriptor> make_shared(const std::vector<std::shared_ptr<ValueDescriptor>>& elements){ static ptr_t make(const std::vector<ValueDescriptor::ptr_t>& elements){
return std::shared_ptr<TupleDescriptor>(new TupleDescriptor(elements)); return ptr_t(new TupleDescriptor(elements));
} }
std::shared_ptr<ValueType> value_type() const override { ValueType::ptr_t value_type() const override {
return m_type; return m_type;
} }
protected: protected:
std::shared_ptr<TupleType> m_type; TupleType::ptr_t m_type;
std::vector<std::shared_ptr<ValueDescriptor>> m_element_descriptors; std::vector<ValueDescriptor::ptr_t> m_element_descriptors;
}; };
} // End of NGRAPH } // End of NGRAPH
...@@ -7,8 +7,21 @@ namespace ngraph { ...@@ -7,8 +7,21 @@ namespace ngraph {
class Function class Function
{ {
std::vector<std::shared_ptr<ValueDescriptor>> m_arguments; public:
std::shared_ptr<ValueDescriptor> m_result; using ptr_t = std::shared_ptr<Function>;
Function(const ValueType::ptr_t& return_type, const std::vector<ValueType::ptr_t>& argument_types)
: m_return_type(return_type)
, m_argument_types(argument_types)
{}
static ptr_t make(const ValueType::ptr_t& return_type, const std::vector<ValueType::ptr_t>& argument_types){
return ptr_t(new Function(return_type, argument_types));
}
protected:
std::vector<std::shared_ptr<ValueType>> m_argument_types;
std::shared_ptr<ValueType> m_return_type;
}; };
......
...@@ -12,19 +12,22 @@ using value_size_t = size_t; ...@@ -12,19 +12,22 @@ using value_size_t = size_t;
// Base type for ngraph values // Base type for ngraph values
class ValueType class ValueType
{ {
public:
using ptr_t = std::shared_ptr<ValueType>;
}; };
class TensorViewType : public ValueType class TensorViewType : public ValueType
{ {
public: public:
using ptr_t = std::shared_ptr<TensorViewType>;
TensorViewType(const ElementType& element_type, const std::vector<value_size_t>& shape) TensorViewType(const ElementType& element_type, const std::vector<value_size_t>& shape)
: m_element_type(element_type) : m_element_type(element_type)
, m_shape(shape) , m_shape(shape)
{} {}
static std::shared_ptr<TensorViewType> make_shared(const ElementType& element_type, const std::vector<value_size_t>& shape){ static ptr_t make(const ElementType& element_type, const std::vector<value_size_t>& shape){
return std::shared_ptr<TensorViewType>(new TensorViewType(element_type, shape)); return ptr_t(new TensorViewType(element_type, shape));
} }
protected: protected:
TensorViewType(const TensorViewType&) = delete; TensorViewType(const TensorViewType&) = delete;
...@@ -35,17 +38,19 @@ protected: ...@@ -35,17 +38,19 @@ protected:
class TupleType : public ValueType class TupleType : public ValueType
{ {
public: public:
TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types) using ptr_t = std::shared_ptr<TupleType>;
TupleType(const std::vector<ValueType::ptr_t>& element_types)
: m_element_types(element_types) : m_element_types(element_types)
{} {}
static std::shared_ptr<TupleType> make_shared(const std::vector<std::shared_ptr<ValueType>>& element_types){ static ptr_t make(const std::vector<ValueType::ptr_t>& element_types){
return std::shared_ptr<TupleType>(new TupleType(element_types)); return ptr_t(new TupleType(element_types));
} }
protected: protected:
// Is this name too similar to TensorViewType.to m_element_type? // Is this name too similar to TensorViewType.to m_element_type?
std::vector<std::shared_ptr<ValueType>> m_element_types; std::vector<ValueType::ptr_t> m_element_types;
}; };
} // End of ngraph } // End of ngraph
\ No newline at end of file
#include "values/descriptors.hpp" #include "values/descriptors.hpp"
#include "values/function.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void build_simple_graph() void build_simple_graph()
{ {
auto tv_a = TensorViewType::make_shared(element_type_float, {2,3, 5}); auto tv_a = TensorViewType::make(element_type_float, {2, 3, 5});
auto tp_b = TupleType::make_shared({tv_a, tv_a}); auto tp_b = TupleType::make({tv_a, tv_a});
auto tp_c = TupleType::make_shared({tp_b, tv_a}); auto tp_c = TupleType::make({tp_b, tv_a});
auto tensor_d = TensorViewDescriptor::make_shared(element_type_float, {2, 3, 5}); auto tensor_d = TensorViewDescriptor::make(element_type_float, {2, 3, 5});
auto tuple_d = TupleDescriptor::make_shared({tensor_d, tensor_d}); auto tuple_d = TupleDescriptor::make({tensor_d, tensor_d});
auto cluster_0 = Function::make(
TensorViewType::make(element_type_float, {32, 3}),
{TensorViewType::make(element_type_float, {7, 3}),
TensorViewType::make(element_type_float, {3}),
TensorViewType::make(element_type_float, {32, 7}),
TensorViewType::make(element_type_float, {32, 7})
});
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment