Commit 271fb025 authored by Scott Cyphers's avatar Scott Cyphers

Organize files, add method to get op from call.

parent c1806e85
...@@ -15,12 +15,12 @@ get_filename_component( NGRAPH_INCLUDE_DIR . ABSOLUTE) ...@@ -15,12 +15,12 @@ get_filename_component( NGRAPH_INCLUDE_DIR . ABSOLUTE)
set(NGRAPH_INCLUDE_DIR "${NGRAPH_INCLUDE_DIR}" PARENT_SCOPE) set(NGRAPH_INCLUDE_DIR "${NGRAPH_INCLUDE_DIR}" PARENT_SCOPE)
set (SRC set (SRC
element_type.cpp
tree.cpp tree.cpp
util.cpp util.cpp
log.cpp log.cpp
values/function.cpp ops/function.cpp
values/op.cpp ops/op.cpp
types/element_type.cpp
) )
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled # NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
......
...@@ -43,10 +43,10 @@ public: ...@@ -43,10 +43,10 @@ public:
private: private:
static std::map<std::string, ElementType> m_element_list; static std::map<std::string, ElementType> m_element_list;
size_t m_bitwidth; size_t m_bitwidth;
bool m_is_float; bool m_is_float;
bool m_is_signed; bool m_is_signed;
const std::string m_cname; const std::string m_cname;
}; };
extern const ngraph::ElementType element_type_float; extern const ngraph::ElementType element_type_float;
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include "values/node.hpp" #include "ngraph/node.hpp"
#include "values/op.hpp" #include "ngraph/op.hpp"
#include "values/type.hpp" #include "ngraph/type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -25,53 +25,39 @@ namespace ngraph ...@@ -25,53 +25,39 @@ namespace ngraph
class Parameter : public Node class Parameter : public Node
{ {
public: public:
Parameter(Function& function, size_t index, const std::shared_ptr<ValueType>& type) using ptr = std::shared_ptr<Parameter>;
: Node({}, type)
, m_function(function) Parameter(Function& function, size_t index);
, m_index(index)
{
}
protected: protected:
Function& m_function; Function& m_function;
size_t m_index; size_t m_index;
}; };
class Result class Result : public TypedValueMixin
{ {
public: public:
void type(const std::shared_ptr<ValueType>& t) { m_type = t; } using ptr = std::shared_ptr<Result>;
void type(const ElementType& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
std::shared_ptr<ValueType> type() const { return m_type; }
std::shared_ptr<Node> value() const { return m_value; } Node::ptr value() const { return m_value; }
void value(const std::shared_ptr<Node>& value) { m_value = value; } void value(const Node::ptr& value) { m_value = value; }
protected: protected:
std::shared_ptr<ValueType> m_type; Node::ptr m_value;
std::shared_ptr<Node> m_value;
}; };
class Function class Function : public Op
{ {
public: public:
Function(size_t n_parameters) Function(size_t n_parameters);
: m_parameters(n_parameters)
{
}
Result* result() { return &m_result; } Result* result() { return &m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; } std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
protected: protected:
std::vector<std::shared_ptr<Parameter>> m_parameters; std::vector<Parameter::ptr> m_parameters;
Result m_result; Result m_result;
}; };
} // end namespace ngraph } // end namespace ngraph
\ No newline at end of file
...@@ -12,14 +12,15 @@ ...@@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once //
// The public API for ngraph++
#include <algorithm> //
#include <memory>
#include <vector>
#include "values/type.hpp" #pragma once
namespace ngraph #include "ngraph/element_type.hpp"
{ #include "ngraph/function.hpp"
} #include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
...@@ -16,41 +16,37 @@ ...@@ -16,41 +16,37 @@
#include <vector> #include <vector>
#include "values/type.hpp" #include "ngraph/type.hpp"
namespace ngraph namespace ngraph
{ {
class Node class Op;
class Node : public TypedValueMixin
{ {
public: public:
Node(const std::vector<std::shared_ptr<Node>>& arguments, using ptr = std::shared_ptr<Node>;
std::shared_ptr<ValueType> type = 0)
Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = 0)
: m_arguments(arguments) : m_arguments(arguments)
, m_type(type) , TypedValueMixin(type)
{ {
} }
virtual ~Node() {} virtual ~Node() {}
virtual std::vector<std::shared_ptr<Node>> dependents() { return m_arguments; } virtual std::vector<Node::ptr> dependents() { return m_arguments; }
void type(const std::shared_ptr<ValueType>& t) { m_type = t; }
void type(const ElementType& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
std::shared_ptr<ValueType> type() const { return m_type; }
protected: protected:
std::vector<std::shared_ptr<Node>> m_arguments; std::vector<Node::ptr> m_arguments;
std::shared_ptr<ValueType> m_type;
}; };
class Call : public Node class Call : public Node
{ {
public:
virtual Op& op() const = 0;
protected: protected:
Call(const std::vector<std::shared_ptr<Node>>& arguments) Call(const std::vector<Node::ptr>& arguments)
: Node(arguments, 0) : Node(arguments, 0)
{ {
} }
......
...@@ -16,9 +16,8 @@ ...@@ -16,9 +16,8 @@
#include <memory> #include <memory>
#include "values/descriptor.hpp" #include "ngraph/node.hpp"
#include "values/node.hpp" #include "ngraph/type.hpp"
#include "values/type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -33,18 +32,20 @@ namespace ngraph ...@@ -33,18 +32,20 @@ namespace ngraph
friend class Broadcast; friend class Broadcast;
public: public:
BroadcastCall(const std::shared_ptr<Node>& arg, size_t axis) BroadcastCall(const Node::ptr& arg, size_t axis)
: Call({arg}) : Call({arg})
, m_axis(axis) , m_axis(axis)
{ {
} }
Op& op() const override;
protected: protected:
size_t m_axis; size_t m_axis;
}; };
public: public:
std::shared_ptr<BroadcastCall> operator()(const std::shared_ptr<Node>& tensor, size_t axis) std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis)
{ {
return std::make_shared<BroadcastCall>(tensor, axis); return std::make_shared<BroadcastCall>(tensor, axis);
} }
...@@ -62,15 +63,16 @@ namespace ngraph ...@@ -62,15 +63,16 @@ namespace ngraph
friend class Dot; friend class Dot;
public: public:
DotCall(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) DotCall(const std::shared_ptr<Node>& arg0, const Node::ptr& arg1)
: Call({arg0, arg1}) : Call({arg0, arg1})
{ {
} }
Op& op() const override;
}; };
public: public:
std::shared_ptr<DotCall> operator()(const std::shared_ptr<Node>& arg0, std::shared_ptr<DotCall> operator()(const Node::ptr& arg0, const Node::ptr& arg1)
const std::shared_ptr<Node>& arg1)
{ {
return std::make_shared<DotCall>(arg0, arg1); return std::make_shared<DotCall>(arg0, arg1);
} }
......
...@@ -14,56 +14,30 @@ ...@@ -14,56 +14,30 @@
#pragma once #pragma once
#include <memory>
#include <vector> #include <vector>
#include "element_type.hpp"
namespace ngraph namespace ngraph
{ {
/**
** Holds the shape of a tensor view.
**/
class Shape class Shape
{ {
public: public:
/**
** \param sizes A sequence of sizes.
**/
Shape(const std::initializer_list<size_t>& sizes) Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes) : m_sizes(sizes)
{ {
} }
protected: /**
std::vector<size_t> m_sizes; ** Conversion to a vector of sizes.
}; **/
operator const std::vector<size_t>&() const { return m_sizes; }
// ValueType is
// TensorViewType
// | TupleType(ValueType[])
class ValueType
{
};
class TensorViewType : public ValueType
{
public:
TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{
}
protected:
TensorViewType(const TensorViewType&) = delete;
const ElementType& m_element_type;
Shape m_shape;
};
class TupleType : public ValueType
{
public:
TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types)
: m_element_types(element_types)
{
}
protected: protected:
std::vector<std::shared_ptr<ValueType>> m_element_types; std::vector<size_t> m_sizes;
}; };
} }
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/element_type.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
/**
** ValueType is
** TensorViewType
** | TupleType(ValueType[])
**/
class ValueType
{
public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
};
/**
** Describes a tensor view; an element type and a shape.
**/
class TensorViewType : public ValueType
{
public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<TensorViewType>;
/**
** /param element_type The type of the tensor elements.
** /param shape The shape of the tensor.
**/
TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{
}
protected:
const ElementType& m_element_type;
Shape m_shape;
};
/**
** Describes a tuple of values; a vector of types
**/
class TupleType : public ValueType
{
public:
/**
** The preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
/**
** Construct empty tuple and add value types later.
**/
TupleType() {}
/**
** /param element_types A vector of types for the tuple elements
**/
TupleType(const std::vector<ValueType::ptr>& element_types)
: m_element_types(element_types)
{
}
const std::vector<ValueType::ptr> element_types() const { return m_element_types; }
std::vector<ValueType::ptr> element_types() { return m_element_types; }
protected:
std::vector<ValueType::ptr> m_element_types;
};
/**
** Mixin for objects with type information
**/
class TypedValueMixin
{
public:
TypedValueMixin(const ValueType::ptr& type = 0)
: m_type(type)
{
}
/**
** Set the type
** /param type The new type
**/
void type(const ValueType::ptr& type) { m_type = type; }
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void type(const ElementType& element_type, const Shape& shape)
{
m_type = TensorViewType::ptr::make_shared(element_type, shape);
}
/**
** The type associated with this value.
**/
ValueType::ptr type() { return m_type; }
/**
** The type associated with this value.
**/
const ValueType::ptr type() const { return m_type; }
protected:
ValueType::ptr m_type;
};
}
\ No newline at end of file
...@@ -12,7 +12,23 @@ ...@@ -12,7 +12,23 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "values/function.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Parameter::Parameter(Function& function, size_t index)
: Node({})
, m_function(function)
, m_index(index)
{
}
Function::Function(size_t n_parameters)
: m_parameters(n_parameters)
{
for (int i = 0; i < n_parameters; i++)
{
m_parameters[i] = Parameter::ptr::make_shared(*this, i);
}
}
...@@ -12,9 +12,20 @@ ...@@ -12,9 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "values/op.hpp" #include "ngraph/ngraph.hpp"
using namespace ngraph; using namespace ngraph;
Broadcast ngraph::op::broadcast{}; Broadcast ngraph::op::broadcast{};
Dot ngraph::op::dot{};
\ No newline at end of file Op& ngraph::Broadcast::BroadcastCall::op() const
{
return op::broadcast;
}
Dot ngraph::op::dot{};
Op& ngraph::Dot::DotCall::op() const
{
return op::dot;
}
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include "element_type.hpp" #include "ngraph/element_type.hpp"
const ngraph::ElementType element_type_float = ngraph::ElementType(32, true, true, "float"); const ngraph::ElementType element_type_float = ngraph::ElementType(32, true, true, "float");
const ngraph::ElementType element_type_int8_t = ngraph::ElementType(8, false, true, "int8_t"); const ngraph::ElementType element_type_int8_t = ngraph::ElementType(8, false, true, "int8_t");
......
...@@ -14,31 +14,31 @@ ...@@ -14,31 +14,31 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "values/type.hpp" #include "ngraph/ngraph.hpp"
#include "values/function.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
TEST(graph, build_simple) TEST(graph, build_simple)
{ {
// // Function with 4 parameters // Function with 4 parameters
// auto cluster_0 = make_shared<Function>(4); auto cluster_0 = make_shared<Function>(4);
// cluster_0->result()->type(element_type_float, Shape {32, 3}); cluster_0->result()->type(element_type_float, {32, 3});
// cluster_0->parameter(0)->type(element_type_float, Shape {Shape {7, 3}}); cluster_0->parameter(0)->type(element_type_float, {7, 3});
// cluster_0->parameter(1)->type(element_type_float, Shape {Shape {3}}); cluster_0->parameter(1)->type(element_type_float, {3});
// cluster_0->parameter(2)->type(element_type_float, Shape {Shape {32, 7}}); cluster_0->parameter(2)->type(element_type_float, {32, 7});
// cluster_0->parameter(3)->type(element_type_float, Shape {Shape {32, 7}}); cluster_0->parameter(3)->type(element_type_float, {32, 7});
// auto arg3 = cluster_0->parameter(3); auto arg3 = cluster_0->parameter(3);
// // call broadcast op on arg3, broadcasting on axis 1. // call broadcast op on arg3, broadcasting on axis 1.
// auto broadcast_1 = op::broadcast(arg3, 1); auto broadcast_1 = op::broadcast(arg3, 1);
// auto arg2 = cluster_0->parameter(2); auto arg2 = cluster_0->parameter(2);
// auto arg0 = cluster_0->parameter(0); auto arg0 = cluster_0->parameter(0);
// // call dot op // call dot op
// auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
// ASSERT_EQ(dot->dependents()[0], arg2); ASSERT_EQ(dot->dependents()[0], arg2);
// // Function returns tuple of dot and broadcast_1. ASSERT_EQ(dot->dependents()[1], arg0);
// cluster_0->result()->value(dot); // Function returns tuple of dot and broadcast_1.
cluster_0->result()->value(dot);
// ASSERT_EQ(cluster_0->result()->value(), dot); ASSERT_EQ(cluster_0->result()->value(), dot);
} }
...@@ -18,6 +18,6 @@ ...@@ -18,6 +18,6 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "element_type.hpp" #include "ngraph/element_type.hpp"
using namespace ngraph; using namespace ngraph;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment