Commit 28f13818 authored by Scott Cyphers's avatar Scott Cyphers

Renamings, use Shape

parent eb2a4980
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "values/types.hpp" #include "values/type.hpp"
namespace ngraph { namespace ngraph {
...@@ -44,16 +44,16 @@ public: ...@@ -44,16 +44,16 @@ public:
: m_type(type) : m_type(type)
{} {}
TensorViewDescriptor(const ElementType& element_type, const std::vector<value_size_t>& shape) TensorViewDescriptor(const ElementType& element_type, const Shape& shape)
: TensorViewDescriptor(TensorViewType::make(element_type, shape)) : TensorViewDescriptor(TensorViewType::make(element_type, shape))
{} {}
static ptr_t make(const TensorViewType::ptr_t& type){ static ptr_t make(const TensorViewType::ptr_t& type){
return ptr_t(new TensorViewDescriptor(type)); return ptr_t::make_shared(type);
} }
static ptr_t make(const ElementType& element_type, const std::vector<value_size_t>& shape){ static ptr_t make(const ElementType& element_type, const Shape& shape){
return ptr_t(new TensorViewDescriptor(element_type, shape)); return ptr_t::make_shared(element_type, shape);
} }
ValueType::ptr_t value_type() const override { ValueType::ptr_t value_type() const override {
...@@ -81,7 +81,7 @@ public: ...@@ -81,7 +81,7 @@ public:
} }
static ptr_t make(const std::vector<ValueDescriptor::ptr_t>& elements){ static ptr_t make(const std::vector<ValueDescriptor::ptr_t>& elements){
return ptr_t(new TupleDescriptor(elements)); return ptr_t::make_shared(elements);
} }
ValueType::ptr_t value_type() const override { ValueType::ptr_t value_type() const override {
......
...@@ -4,10 +4,10 @@ using namespace std; ...@@ -4,10 +4,10 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
Parameter::ptr_t Parameter::make(Function& function, size_t index, const ValueType::ptr_t& output_type){ Parameter::ptr_t Parameter::make(Function& function, size_t index, const ValueType::ptr_t& output_type){
return ptr_t(new Parameter(function, index, output_type)); return ptr_t::make_shared(function, index, output_type);
} }
Function::ptr_t Function::make(const ValueType::ptr_t& return_type, const std::vector<ValueType::ptr_t>& argument_types){ Function::ptr_t Function::make(const ValueType::ptr_t& return_type, const std::vector<ValueType::ptr_t>& argument_types){
return ptr_t(new Function(return_type, argument_types)); return ptr_t::make_shared(return_type, argument_types);
} }
#pragma once #pragma once
#include "values/descriptors.hpp" #include "values/descriptor.hpp"
#include "values/op.hpp" #include "values/op.hpp"
#include "values/types.hpp" #include "values/type.hpp"
namespace ngraph { namespace ngraph {
...@@ -14,13 +14,14 @@ public: ...@@ -14,13 +14,14 @@ public:
using ptr_t = std::shared_ptr<Parameter>; using ptr_t = std::shared_ptr<Parameter>;
static ptr_t make(Function& function, size_t index, const ValueType::ptr_t& output_type); static ptr_t make(Function& function, size_t index, const ValueType::ptr_t& output_type);
protected:
Parameter(Function& function, size_t index, const ValueType::ptr_t& output_type) Parameter(Function& function, size_t index, const ValueType::ptr_t& output_type)
: Op({}, output_type) : Op({}, output_type)
, m_function(function) , m_function(function)
, m_index(index) , m_index(index)
{} {}
protected:
Function& m_function; Function& m_function;
size_t m_index; size_t m_index;
}; };
...@@ -30,7 +31,6 @@ class Function ...@@ -30,7 +31,6 @@ class Function
public: public:
using ptr_t = std::shared_ptr<Function>; using ptr_t = std::shared_ptr<Function>;
protected:
Function(const ValueType::ptr_t& return_type, Function(const ValueType::ptr_t& return_type,
const std::vector<ValueType::ptr_t>& argument_types) const std::vector<ValueType::ptr_t>& argument_types)
: m_return_type(return_type) : m_return_type(return_type)
...@@ -42,7 +42,6 @@ protected: ...@@ -42,7 +42,6 @@ protected:
} }
} }
public:
static ptr_t make(const ValueType::ptr_t& return_type, static ptr_t make(const ValueType::ptr_t& return_type,
const std::vector<ValueType::ptr_t>& argument_types); const std::vector<ValueType::ptr_t>& argument_types);
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#include <memory> #include <memory>
#include "values/descriptors.hpp" #include "values/descriptor.hpp"
#include "values/types.hpp" #include "values/type.hpp"
namespace ngraph { namespace ngraph {
...@@ -27,7 +27,6 @@ class Broadcast : public Op ...@@ -27,7 +27,6 @@ class Broadcast : public Op
public: public:
using ptr_t = std::shared_ptr<Broadcast>; using ptr_t = std::shared_ptr<Broadcast>;
protected:
Broadcast(const Op::ptr_t& x, std::vector<size_t> dims) Broadcast(const Op::ptr_t& x, std::vector<size_t> dims)
: Op({x}, 0) : Op({x}, 0)
, m_dims(dims) , m_dims(dims)
...@@ -35,11 +34,20 @@ protected: ...@@ -35,11 +34,20 @@ protected:
public: public:
static ptr_t make(const Op::ptr_t& x, std::vector<size_t> dims){ static ptr_t make(const Op::ptr_t& x, std::vector<size_t> dims){
return ptr_t(new Broadcast(x, dims)); return ptr_t::make_shared(x, dims);
} }
protected: protected:
std::vector<size_t> m_dims; std::vector<size_t> m_dims;
}; };
class Tuple : public Op
{
public:
Tuple(const std::vector<ptr_t>& inputs)
: Op(inputs, 0)
{
}
};
} // end of namespace ngraph } // end of namespace ngraph
\ No newline at end of file
...@@ -12,6 +12,17 @@ class TupleDescriptor; ...@@ -12,6 +12,17 @@ class TupleDescriptor;
using value_size_t = size_t; using value_size_t = size_t;
class Shape
{
public:
Shape(const std::initializer_list<value_size_t>& sizes)
: m_sizes(sizes)
{}
protected:
std::vector<value_size_t> m_sizes;
};
// Base type for ngraph values // Base type for ngraph values
class ValueType class ValueType
{ {
...@@ -25,18 +36,19 @@ public: ...@@ -25,18 +36,19 @@ public:
using ptr_t = std::shared_ptr<TensorViewType>; using ptr_t = std::shared_ptr<TensorViewType>;
using descriptor_t = TensorViewDescriptor; using descriptor_t = TensorViewDescriptor;
TensorViewType(const ElementType& element_type, const std::vector<value_size_t>& shape) TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type) : m_element_type(element_type)
, m_shape(shape) , m_shape(shape)
{} {}
static ptr_t make(const ElementType& element_type, const std::vector<value_size_t>& shape){ static ptr_t make(const ElementType& element_type, const Shape& shape){
return ptr_t(new TensorViewType(element_type, shape)); return ptr_t::make_shared(element_type, shape);
} }
protected: protected:
TensorViewType(const TensorViewType&) = delete; TensorViewType(const TensorViewType&) = delete;
const ElementType& m_element_type; const ElementType& m_element_type;
std::vector<value_size_t> m_shape; Shape m_shape;
}; };
class TupleType : public ValueType class TupleType : public ValueType
...@@ -50,7 +62,7 @@ public: ...@@ -50,7 +62,7 @@ public:
{} {}
static ptr_t make(const std::vector<ValueType::ptr_t>& element_types){ static ptr_t make(const std::vector<ValueType::ptr_t>& element_types){
return ptr_t(new TupleType(element_types)); return ptr_t::make_shared(element_types);
} }
protected: protected:
......
#include "values/descriptors.hpp" #include "values/descriptor.hpp"
#include "values/function.hpp" #include "values/function.hpp"
using namespace std; using namespace std;
...@@ -6,17 +6,15 @@ using namespace ngraph; ...@@ -6,17 +6,15 @@ using namespace ngraph;
void build_simple_graph() void build_simple_graph()
{ {
auto tv_a = TensorViewType::make(element_type_float, {2, 3, 5});
auto tp_b = TupleType::make({tv_a, tv_a});
auto tp_c = TupleType::make({tp_b, tv_a});
auto tensor_d = TensorViewDescriptor::make(element_type_float, {2, 3, 5});
auto tuple_d = TupleDescriptor::make({tensor_d, tensor_d});
auto cluster_0 = Function::make( auto cluster_0 = Function::make(
TensorViewType::make(element_type_float, {32, 3}), TensorViewType::make(element_type_float, Shape({32, 3})),
{TensorViewType::make(element_type_float, {7, 3}), {TensorViewType::make(element_type_float, Shape({7, 3})),
TensorViewType::make(element_type_float, {3}), TensorViewType::make(element_type_float, Shape({3})),
TensorViewType::make(element_type_float, {32, 7}), TensorViewType::make(element_type_float, Shape({32, 7})),
TensorViewType::make(element_type_float, {32, 7}) TensorViewType::make(element_type_float, Shape({32, 7}))
}); });
auto arg3 = cluster_0->parameter(3);
auto broadcast_1 = Broadcast::make(arg3, {1});
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment