Commit 0b1d09ae authored by Scott Cyphers's avatar Scott Cyphers

Type construction, simple value descriptors

parent 8c16125d
......@@ -15,8 +15,13 @@ public:
class TensorDescriptor
{
public:
TensorDescriptor(const ElementType& element_type)
: m_element_type(element_type)
{}
protected:
ElementType element_type;
const ElementType& m_element_type;
};
class TensorLayoutDescriptor
......@@ -27,24 +32,54 @@ class TensorLayoutDescriptor
class TensorViewDescriptor : public ValueDescriptor
{
public:
TensorViewDescriptor(const std::shared_ptr<TensorViewType>& type)
: m_type(type)
{}
TensorViewDescriptor(const ElementType& element_type, const std::vector<value_size_t>& shape)
: TensorViewDescriptor(TensorViewType::make_shared(element_type, shape))
{}
static std::shared_ptr<TensorViewDescriptor> make_shared(const std::shared_ptr<TensorViewType>& type){
return std::shared_ptr<TensorViewDescriptor>(new TensorViewDescriptor(type));
}
static std::shared_ptr<TensorViewDescriptor> make_shared(const ElementType& element_type, const std::vector<value_size_t>& shape){
return std::shared_ptr<TensorViewDescriptor>(new TensorViewDescriptor(element_type, shape));
}
std::shared_ptr<ValueType> value_type() const override {
return m_type;
}
protected:
std::shared_ptr<TensorViewType> m_type;
TensorDescriptor m_tensor_descriptor;
TensorLayoutDescriptor m_tensor_layout_descriptor;
std::shared_ptr<TensorDescriptor> m_tensor_descriptor;
std::shared_ptr<TensorLayoutDescriptor> m_tensor_layout_descriptor;
};
class TupleDescriptor : public ValueDescriptor
{
public:
TupleDescriptor(const std::vector<std::shared_ptr<ValueDescriptor>>& elements)
: m_element_descriptors(elements)
{
std::vector<std::shared_ptr<ValueType>> types;
for(auto elt : elements){
types.push_back(elt->value_type());
}
m_type = TupleType::make_shared(types);
}
static std::shared_ptr<TupleDescriptor> make_shared(const std::vector<std::shared_ptr<ValueDescriptor>>& elements){
return std::shared_ptr<TupleDescriptor>(new TupleDescriptor(elements));
}
std::shared_ptr<ValueType> value_type() const override {
return m_type;
}
protected:
std::shared_ptr<TupleType> m_type;
std::vector<ValueDescriptor> m_element_descriptors;
std::vector<std::shared_ptr<ValueDescriptor>> m_element_descriptors;
};
} // End of NGRAPH
#pragma once
#include "values/descriptors.hpp"
#include "values/types.hpp"
namespace ngraph {
class Function
{
std::vector<std::shared_ptr<ValueDescriptor>> m_arguments;
std::shared_ptr<ValueDescriptor> m_result;
};
} // end namespace ngraph
\ No newline at end of file
......@@ -17,16 +17,35 @@ class ValueType
class TensorViewType : public ValueType
{
public:
TensorViewType(const ElementType& element_type, const std::vector<value_size_t>& shape)
: m_element_type(element_type)
, m_shape(shape)
{}
static std::shared_ptr<TensorViewType> make_shared(const ElementType& element_type, const std::vector<value_size_t>& shape){
return std::shared_ptr<TensorViewType>(new TensorViewType(element_type, shape));
}
protected:
ElementType m_element_type;
TensorViewType(const TensorViewType&) = delete;
const ElementType& m_element_type;
std::vector<value_size_t> m_shape;
};
class TupleType : public ValueType
{
public:
TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types)
: m_element_types(element_types)
{}
static std::shared_ptr<TupleType> make_shared(const std::vector<std::shared_ptr<ValueType>>& element_types){
return std::shared_ptr<TupleType>(new TupleType(element_types));
}
protected:
// Is this name too similar to TensorViewType.to m_element_type?
std::vector<ValueType> m_element_types;
std::vector<std::shared_ptr<ValueType>> m_element_types;
};
} // End of ngraph
\ No newline at end of file
......@@ -5,4 +5,9 @@ using namespace ngraph;
void build_simple_graph()
{
}
\ No newline at end of file
auto tv_a = TensorViewType::make_shared(element_type_float, {2,3, 5});
auto tp_b = TupleType::make_shared({tv_a, tv_a});
auto tp_c = TupleType::make_shared({tp_b, tv_a});
auto tensor_d = TensorViewDescriptor::make_shared(element_type_float, {2, 3, 5});
auto tuple_d = TupleDescriptor::make_shared({tensor_d, tensor_d});
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment