Commit 19f16bc1 authored by Scott Cyphers's avatar Scott Cyphers

Finish basic graph building.

parent d37359aa
...@@ -29,6 +29,7 @@ set (SRC ...@@ -29,6 +29,7 @@ set (SRC
transformers/op_graph.cpp transformers/op_graph.cpp
values/function.cpp values/function.cpp
values/op.cpp
) )
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled # NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
......
...@@ -22,5 +22,4 @@ ...@@ -22,5 +22,4 @@
namespace ngraph namespace ngraph
{ {
} }
...@@ -20,39 +20,41 @@ ...@@ -20,39 +20,41 @@
namespace ngraph namespace ngraph
{ {
class Function; class Function;
class Parameter : public Node class Parameter : public Node
{ {
public: public:
Parameter(Function& function, size_t index, const std::shared_ptr<ValueType>& type) Parameter(Function& function, size_t index, const std::shared_ptr<ValueType>& type)
: Node(type) : Node({}, type)
, m_function(function) , m_function(function)
, m_index(index) , m_index(index)
{} {
}
protected: protected:
Function& m_function; Function& m_function;
size_t m_index; size_t m_index;
}; };
class Result { class Result
{
public: public:
void type(const std::shared_ptr<ValueType>& t){ void type(const std::shared_ptr<ValueType>& t) { m_type = t; }
m_type = t;
}
void type(const ElementType& element_type, const Shape& shape){ void type(const ElementType& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape); m_type = std::make_shared<TensorViewType>(element_type, shape);
} }
std::shared_ptr<ValueType> type() const { std::shared_ptr<ValueType> type() const { return m_type; }
return m_type;
} std::shared_ptr<Node> value() const { return m_value; }
void value(const std::shared_ptr<Node>& value) { m_value = value; }
protected: protected:
std::shared_ptr<ValueType> m_type; std::shared_ptr<ValueType> m_type;
std::shared_ptr<Node> m_value;
}; };
class Function class Function
...@@ -60,15 +62,12 @@ namespace ngraph ...@@ -60,15 +62,12 @@ namespace ngraph
public: public:
Function(size_t n_parameters) Function(size_t n_parameters)
: m_parameters(n_parameters) : m_parameters(n_parameters)
{} {
Result *result(){
return &m_result;
} }
std::shared_ptr<Parameter> parameter(size_t i){ Result* result() { return &m_result; }
return m_parameters[i];
} std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
protected: protected:
std::vector<std::shared_ptr<Parameter>> m_parameters; std::vector<std::shared_ptr<Parameter>> m_parameters;
......
...@@ -20,34 +20,39 @@ ...@@ -20,34 +20,39 @@
namespace ngraph namespace ngraph
{ {
class Node class Node
{ {
public: public:
Node(std::shared_ptr<ValueType> type=0) Node(const std::vector<std::shared_ptr<Node>>& arguments,
: m_type(type) std::shared_ptr<ValueType> type = 0)
{} : m_arguments(arguments)
, m_type(type)
virtual ~Node(){} {
virtual std::vector<std::shared_ptr<Node>> dependents() {
return m_parameters;
} }
void type(const std::shared_ptr<ValueType>& t){ virtual ~Node() {}
m_type = t; virtual std::vector<std::shared_ptr<Node>> dependents() { return m_arguments; }
}
void type(const std::shared_ptr<ValueType>& t) { m_type = t; }
void type(const ElementType& element_type, const Shape& shape){ void type(const ElementType& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape); m_type = std::make_shared<TensorViewType>(element_type, shape);
} }
std::shared_ptr<ValueType> type() const { std::shared_ptr<ValueType> type() const { return m_type; }
return m_type;
}
protected: protected:
std::vector<std::shared_ptr<Node>> m_parameters; std::vector<std::shared_ptr<Node>> m_arguments;
std::shared_ptr<ValueType> m_type; std::shared_ptr<ValueType> m_type;
}; };
class Call : public Node
{
protected:
Call(const std::vector<std::shared_ptr<Node>>& arguments)
: Node(arguments, 0)
{
}
};
} }
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "values/op.hpp"
using namespace ngraph;
Broadcast ngraph::op::broadcast{};
Dot ngraph::op::dot{};
\ No newline at end of file
...@@ -17,15 +17,67 @@ ...@@ -17,15 +17,67 @@
#include <memory> #include <memory>
#include "values/descriptor.hpp" #include "values/descriptor.hpp"
#include "values/node.hpp"
#include "values/type.hpp" #include "values/type.hpp"
namespace ngraph namespace ngraph
{ {
class Call : public Node class Op
{ {
};
class Broadcast : public Op
{
class BroadcastCall : public Call
{
friend class Broadcast;
public:
BroadcastCall(const std::shared_ptr<Node>& arg, size_t axis)
: Call({arg})
, m_axis(axis)
{
}
protected: protected:
std::vector<std::shared_ptr<Node>> m_args; size_t m_axis;
};
public:
std::shared_ptr<BroadcastCall> operator()(const std::shared_ptr<Node>& tensor, size_t axis)
{
return std::make_shared<BroadcastCall>(tensor, axis);
}
}; };
namespace op
{
extern Broadcast broadcast;
}
class Dot : public Op
{
class DotCall : public Call
{
friend class Dot;
public:
DotCall(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Call({arg0, arg1})
{
}
};
public:
std::shared_ptr<DotCall> operator()(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return std::make_shared<DotCall>(arg0, arg1);
}
};
namespace op
{
extern Dot dot;
}
} }
\ No newline at end of file
...@@ -19,14 +19,15 @@ ...@@ -19,14 +19,15 @@
#include "element_type.hpp" #include "element_type.hpp"
namespace ngraph { namespace ngraph
{
class Shape class Shape
{ {
public: public:
Shape(const std::initializer_list<size_t>& sizes) Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes) : m_sizes(sizes)
{} {
}
protected: protected:
std::vector<size_t> m_sizes; std::vector<size_t> m_sizes;
...@@ -45,7 +46,8 @@ namespace ngraph { ...@@ -45,7 +46,8 @@ namespace ngraph {
TensorViewType(const ElementType& element_type, const Shape& shape) TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type) : m_element_type(element_type)
, m_shape(shape) , m_shape(shape)
{} {
}
protected: protected:
TensorViewType(const TensorViewType&) = delete; TensorViewType(const TensorViewType&) = delete;
...@@ -56,10 +58,10 @@ namespace ngraph { ...@@ -56,10 +58,10 @@ namespace ngraph {
class TupleType : public ValueType class TupleType : public ValueType
{ {
public: public:
TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types) TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types)
: m_element_types(element_types) : m_element_types(element_types)
{} {
}
protected: protected:
std::vector<std::shared_ptr<ValueType>> m_element_types; std::vector<std::shared_ptr<ValueType>> m_element_types;
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "values/type.hpp" #include "values/type.hpp"
#include "values/function.hpp" #include "values/function.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void build_simple_graph() TEST(graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto cluster_0 = make_shared<Function>(4); auto cluster_0 = make_shared<Function>(4);
...@@ -29,11 +31,14 @@ void build_simple_graph() ...@@ -29,11 +31,14 @@ void build_simple_graph()
cluster_0->parameter(3)->type(element_type_float, Shape {Shape {32, 7}}); cluster_0->parameter(3)->type(element_type_float, Shape {Shape {32, 7}});
auto arg3 = cluster_0->parameter(3); auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1. // call broadcast op on arg3, broadcasting on axis 1.
//auto broadcast_1 = op::broadcast(arg3, 1); auto broadcast_1 = op::broadcast(arg3, 1);
auto arg2 = cluster_0->parameter(2); auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0); auto arg0 = cluster_0->parameter(0);
// call dot op // call dot op
//auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->dependents()[0], arg2);
// Function returns tuple of dot and broadcast_1. // Function returns tuple of dot and broadcast_1.
//cluster_0.result->value(op::tuple(dot, broadcast_1)); cluster_0->result()->value(dot);
ASSERT_EQ(cluster_0->result()->value(), dot);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment