Commit 28f13818 authored by Scott Cyphers's avatar Scott Cyphers

Renamings, use Shape

parent eb2a4980
......@@ -4,7 +4,7 @@
#include <memory>
#include <vector>
#include "values/types.hpp"
#include "values/type.hpp"
namespace ngraph {
......@@ -44,16 +44,16 @@ public:
: m_type(type)
{}
TensorViewDescriptor(const ElementType& element_type, const std::vector<value_size_t>& shape)
TensorViewDescriptor(const ElementType& element_type, const Shape& shape)
: TensorViewDescriptor(TensorViewType::make(element_type, shape))
{}
static ptr_t make(const TensorViewType::ptr_t& type){
return ptr_t(new TensorViewDescriptor(type));
return ptr_t::make_shared(type);
}
static ptr_t make(const ElementType& element_type, const std::vector<value_size_t>& shape){
return ptr_t(new TensorViewDescriptor(element_type, shape));
static ptr_t make(const ElementType& element_type, const Shape& shape){
return ptr_t::make_shared(element_type, shape);
}
ValueType::ptr_t value_type() const override {
......@@ -81,7 +81,7 @@ public:
}
static ptr_t make(const std::vector<ValueDescriptor::ptr_t>& elements){
return ptr_t(new TupleDescriptor(elements));
return ptr_t::make_shared(elements);
}
ValueType::ptr_t value_type() const override {
......
......@@ -4,10 +4,10 @@ using namespace std;
using namespace ngraph;
Parameter::ptr_t Parameter::make(Function& function, size_t index, const ValueType::ptr_t& output_type){
return ptr_t(new Parameter(function, index, output_type));
return ptr_t::make_shared(function, index, output_type);
}
Function::ptr_t Function::make(const ValueType::ptr_t& return_type, const std::vector<ValueType::ptr_t>& argument_types){
return ptr_t(new Function(return_type, argument_types));
return ptr_t::make_shared(return_type, argument_types);
}
#pragma once
#include "values/descriptors.hpp"
#include "values/descriptor.hpp"
#include "values/op.hpp"
#include "values/types.hpp"
#include "values/type.hpp"
namespace ngraph {
......@@ -14,13 +14,14 @@ public:
using ptr_t = std::shared_ptr<Parameter>;
static ptr_t make(Function& function, size_t index, const ValueType::ptr_t& output_type);
protected:
Parameter(Function& function, size_t index, const ValueType::ptr_t& output_type)
: Op({}, output_type)
, m_function(function)
, m_index(index)
{}
protected:
Function& m_function;
size_t m_index;
};
......@@ -30,7 +31,6 @@ class Function
public:
using ptr_t = std::shared_ptr<Function>;
protected:
Function(const ValueType::ptr_t& return_type,
const std::vector<ValueType::ptr_t>& argument_types)
: m_return_type(return_type)
......@@ -42,7 +42,6 @@ protected:
}
}
public:
static ptr_t make(const ValueType::ptr_t& return_type,
const std::vector<ValueType::ptr_t>& argument_types);
......
......@@ -2,8 +2,8 @@
#include <memory>
#include "values/descriptors.hpp"
#include "values/types.hpp"
#include "values/descriptor.hpp"
#include "values/type.hpp"
namespace ngraph {
......@@ -27,7 +27,6 @@ class Broadcast : public Op
public:
using ptr_t = std::shared_ptr<Broadcast>;
protected:
Broadcast(const Op::ptr_t& x, std::vector<size_t> dims)
: Op({x}, 0)
, m_dims(dims)
......@@ -35,11 +34,20 @@ protected:
public:
static ptr_t make(const Op::ptr_t& x, std::vector<size_t> dims){
return ptr_t(new Broadcast(x, dims));
return ptr_t::make_shared(x, dims);
}
protected:
std::vector<size_t> m_dims;
};
class Tuple : public Op
{
public:
Tuple(const std::vector<ptr_t>& inputs)
: Op(inputs, 0)
{
}
};
} // end of namespace ngraph
\ No newline at end of file
......@@ -12,6 +12,17 @@ class TupleDescriptor;
using value_size_t = size_t;
class Shape
{
public:
Shape(const std::initializer_list<value_size_t>& sizes)
: m_sizes(sizes)
{}
protected:
std::vector<value_size_t> m_sizes;
};
// Base type for ngraph values
class ValueType
{
......@@ -25,18 +36,19 @@ public:
using ptr_t = std::shared_ptr<TensorViewType>;
using descriptor_t = TensorViewDescriptor;
TensorViewType(const ElementType& element_type, const std::vector<value_size_t>& shape)
TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{}
static ptr_t make(const ElementType& element_type, const std::vector<value_size_t>& shape){
return ptr_t(new TensorViewType(element_type, shape));
static ptr_t make(const ElementType& element_type, const Shape& shape){
return ptr_t::make_shared(element_type, shape);
}
protected:
TensorViewType(const TensorViewType&) = delete;
const ElementType& m_element_type;
std::vector<value_size_t> m_shape;
Shape m_shape;
};
class TupleType : public ValueType
......@@ -50,7 +62,7 @@ public:
{}
static ptr_t make(const std::vector<ValueType::ptr_t>& element_types){
return ptr_t(new TupleType(element_types));
return ptr_t::make_shared(element_types);
}
protected:
......
#include "values/descriptors.hpp"
#include "values/descriptor.hpp"
#include "values/function.hpp"
using namespace std;
......@@ -6,17 +6,15 @@ using namespace ngraph;
void build_simple_graph()
{
auto tv_a = TensorViewType::make(element_type_float, {2, 3, 5});
auto tp_b = TupleType::make({tv_a, tv_a});
auto tp_c = TupleType::make({tp_b, tv_a});
auto tensor_d = TensorViewDescriptor::make(element_type_float, {2, 3, 5});
auto tuple_d = TupleDescriptor::make({tensor_d, tensor_d});
auto cluster_0 = Function::make(
TensorViewType::make(element_type_float, {32, 3}),
{TensorViewType::make(element_type_float, {7, 3}),
TensorViewType::make(element_type_float, {3}),
TensorViewType::make(element_type_float, {32, 7}),
TensorViewType::make(element_type_float, {32, 7})
TensorViewType::make(element_type_float, Shape({32, 3})),
{TensorViewType::make(element_type_float, Shape({7, 3})),
TensorViewType::make(element_type_float, Shape({3})),
TensorViewType::make(element_type_float, Shape({32, 7})),
TensorViewType::make(element_type_float, Shape({32, 7}))
});
auto arg3 = cluster_0->parameter(3);
auto broadcast_1 = Broadcast::make(arg3, {1});
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment