Commit 19f16bc1 authored by Scott Cyphers's avatar Scott Cyphers

Finish basic graph building.

parent d37359aa
......@@ -29,6 +29,7 @@ set (SRC
transformers/op_graph.cpp
values/function.cpp
values/op.cpp
)
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
......
......@@ -20,7 +20,6 @@
#include "values/type.hpp"
namespace ngraph
namespace ngraph
{
}
......@@ -18,61 +18,60 @@
#include "values/op.hpp"
#include "values/type.hpp"
namespace ngraph
namespace ngraph
{
class Function;
class Parameter : public Node
{
public:
Parameter(Function& function, size_t index, const std::shared_ptr<ValueType>& type)
: Node(type)
, m_function(function)
, m_index(index)
{}
: Node({}, type)
, m_function(function)
, m_index(index)
{
}
protected:
Function& m_function;
size_t m_index;
size_t m_index;
};
class Result {
class Result
{
public:
void type(const std::shared_ptr<ValueType>& t){
m_type = t;
}
void type(const std::shared_ptr<ValueType>& t) { m_type = t; }
void type(const ElementType& element_type, const Shape& shape){
void type(const ElementType& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
std::shared_ptr<ValueType> type() const {
return m_type;
}
std::shared_ptr<ValueType> type() const { return m_type; }
std::shared_ptr<Node> value() const { return m_value; }
void value(const std::shared_ptr<Node>& value) { m_value = value; }
protected:
std::shared_ptr<ValueType> m_type;
std::shared_ptr<Node> m_value;
};
class Function
{
public:
Function(size_t n_parameters)
: m_parameters(n_parameters)
{}
Result *result(){
return &m_result;
: m_parameters(n_parameters)
{
}
std::shared_ptr<Parameter> parameter(size_t i){
return m_parameters[i];
}
Result* result() { return &m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
protected:
std::vector<std::shared_ptr<Parameter>> m_parameters;
Result m_result;
Result m_result;
};
} // end namespace ngraph
\ No newline at end of file
......@@ -20,34 +20,39 @@
namespace ngraph
{
class Node
{
public:
Node(std::shared_ptr<ValueType> type=0)
: m_type(type)
{}
virtual ~Node(){}
virtual std::vector<std::shared_ptr<Node>> dependents() {
return m_parameters;
Node(const std::vector<std::shared_ptr<Node>>& arguments,
std::shared_ptr<ValueType> type = 0)
: m_arguments(arguments)
, m_type(type)
{
}
void type(const std::shared_ptr<ValueType>& t){
m_type = t;
}
virtual ~Node() {}
virtual std::vector<std::shared_ptr<Node>> dependents() { return m_arguments; }
void type(const ElementType& element_type, const Shape& shape){
void type(const std::shared_ptr<ValueType>& t) { m_type = t; }
void type(const ElementType& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
std::shared_ptr<ValueType> type() const {
return m_type;
}
std::shared_ptr<ValueType> type() const { return m_type; }
protected:
std::vector<std::shared_ptr<Node>> m_parameters;
std::shared_ptr<ValueType> m_type;
std::vector<std::shared_ptr<Node>> m_arguments;
std::shared_ptr<ValueType> m_type;
};
class Call : public Node
{
protected:
Call(const std::vector<std::shared_ptr<Node>>& arguments)
: Node(arguments, 0)
{
}
};
}
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "values/op.hpp"
using namespace ngraph;
Broadcast ngraph::op::broadcast{};
Dot ngraph::op::dot{};
\ No newline at end of file
......@@ -17,15 +17,67 @@
#include <memory>
#include "values/descriptor.hpp"
#include "values/node.hpp"
#include "values/type.hpp"
namespace ngraph
{
class Call : public Node
class Op
{
protected:
std::vector<std::shared_ptr<Node>> m_args;
};
class Broadcast : public Op
{
class BroadcastCall : public Call
{
friend class Broadcast;
public:
BroadcastCall(const std::shared_ptr<Node>& arg, size_t axis)
: Call({arg})
, m_axis(axis)
{
}
protected:
size_t m_axis;
};
public:
std::shared_ptr<BroadcastCall> operator()(const std::shared_ptr<Node>& tensor, size_t axis)
{
return std::make_shared<BroadcastCall>(tensor, axis);
}
};
namespace op
{
extern Broadcast broadcast;
}
class Dot : public Op
{
class DotCall : public Call
{
friend class Dot;
public:
DotCall(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Call({arg0, arg1})
{
}
};
public:
std::shared_ptr<DotCall> operator()(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return std::make_shared<DotCall>(arg0, arg1);
}
};
namespace op
{
extern Dot dot;
}
}
\ No newline at end of file
......@@ -19,14 +19,15 @@
#include "element_type.hpp"
namespace ngraph {
namespace ngraph
{
class Shape
{
public:
Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes)
{}
: m_sizes(sizes)
{
}
protected:
std::vector<size_t> m_sizes;
......@@ -43,23 +44,24 @@ namespace ngraph {
{
public:
TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{}
: m_element_type(element_type)
, m_shape(shape)
{
}
protected:
TensorViewType(const TensorViewType&) = delete;
const ElementType& m_element_type;
Shape m_shape;
Shape m_shape;
};
class TupleType : public ValueType
{
public:
TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types)
: m_element_types(element_types)
{}
: m_element_types(element_types)
{
}
protected:
std::vector<std::shared_ptr<ValueType>> m_element_types;
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "values/type.hpp"
#include "values/function.hpp"
using namespace std;
using namespace ngraph;
void build_simple_graph()
TEST(graph, build_simple)
{
// Function with 4 parameters
auto cluster_0 = make_shared<Function>(4);
......@@ -29,11 +31,14 @@ void build_simple_graph()
cluster_0->parameter(3)->type(element_type_float, Shape {Shape {32, 7}});
auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1.
//auto broadcast_1 = op::broadcast(arg3, 1);
auto broadcast_1 = op::broadcast(arg3, 1);
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
// call dot op
//auto dot = op::dot(arg2, arg0);
auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->dependents()[0], arg2);
// Function returns tuple of dot and broadcast_1.
//cluster_0.result->value(op::tuple(dot, broadcast_1));
cluster_0->result()->value(dot);
ASSERT_EQ(cluster_0->result()->value(), dot);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment