Commit 271fb025 authored by Scott Cyphers's avatar Scott Cyphers

Organize files, add method to get op from call.

parent c1806e85
......@@ -15,12 +15,12 @@ get_filename_component( NGRAPH_INCLUDE_DIR . ABSOLUTE)
set(NGRAPH_INCLUDE_DIR "${NGRAPH_INCLUDE_DIR}" PARENT_SCOPE)
set (SRC
element_type.cpp
tree.cpp
util.cpp
log.cpp
values/function.cpp
values/op.cpp
ops/function.cpp
ops/op.cpp
types/element_type.cpp
)
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
......
......@@ -14,9 +14,9 @@
#pragma once
#include "values/node.hpp"
#include "values/op.hpp"
#include "values/type.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/type.hpp"
namespace ngraph
{
......@@ -25,52 +25,38 @@ namespace ngraph
class Parameter : public Node
{
public:
Parameter(Function& function, size_t index, const std::shared_ptr<ValueType>& type)
: Node({}, type)
, m_function(function)
, m_index(index)
{
}
using ptr = std::shared_ptr<Parameter>;
Parameter(Function& function, size_t index);
protected:
Function& m_function;
size_t m_index;
};
class Result
class Result : public TypedValueMixin
{
public:
void type(const std::shared_ptr<ValueType>& t) { m_type = t; }
void type(const ElementType& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
std::shared_ptr<ValueType> type() const { return m_type; }
using ptr = std::shared_ptr<Result>;
std::shared_ptr<Node> value() const { return m_value; }
void value(const std::shared_ptr<Node>& value) { m_value = value; }
Node::ptr value() const { return m_value; }
void value(const Node::ptr& value) { m_value = value; }
protected:
std::shared_ptr<ValueType> m_type;
std::shared_ptr<Node> m_value;
Node::ptr m_value;
};
class Function
class Function : public Op
{
public:
Function(size_t n_parameters)
: m_parameters(n_parameters)
{
}
Function(size_t n_parameters);
Result* result() { return &m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
protected:
std::vector<std::shared_ptr<Parameter>> m_parameters;
std::vector<Parameter::ptr> m_parameters;
Result m_result;
};
......
......@@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <algorithm>
#include <memory>
#include <vector>
//
// The public API for ngraph++
//
#include "values/type.hpp"
#pragma once
namespace ngraph
{
}
#include "ngraph/element_type.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
......@@ -16,41 +16,37 @@
#include <vector>
#include "values/type.hpp"
#include "ngraph/type.hpp"
namespace ngraph
{
class Node
class Op;
class Node : public TypedValueMixin
{
public:
Node(const std::vector<std::shared_ptr<Node>>& arguments,
std::shared_ptr<ValueType> type = 0)
using ptr = std::shared_ptr<Node>;
Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = 0)
: m_arguments(arguments)
, m_type(type)
, TypedValueMixin(type)
{
}
virtual ~Node() {}
virtual std::vector<std::shared_ptr<Node>> dependents() { return m_arguments; }
void type(const std::shared_ptr<ValueType>& t) { m_type = t; }
void type(const ElementType& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
std::shared_ptr<ValueType> type() const { return m_type; }
virtual std::vector<Node::ptr> dependents() { return m_arguments; }
protected:
std::vector<std::shared_ptr<Node>> m_arguments;
std::shared_ptr<ValueType> m_type;
std::vector<Node::ptr> m_arguments;
};
class Call : public Node
{
public:
virtual Op& op() const = 0;
protected:
Call(const std::vector<std::shared_ptr<Node>>& arguments)
Call(const std::vector<Node::ptr>& arguments)
: Node(arguments, 0)
{
}
......
......@@ -16,9 +16,8 @@
#include <memory>
#include "values/descriptor.hpp"
#include "values/node.hpp"
#include "values/type.hpp"
#include "ngraph/node.hpp"
#include "ngraph/type.hpp"
namespace ngraph
{
......@@ -33,18 +32,20 @@ namespace ngraph
friend class Broadcast;
public:
BroadcastCall(const std::shared_ptr<Node>& arg, size_t axis)
BroadcastCall(const Node::ptr& arg, size_t axis)
: Call({arg})
, m_axis(axis)
{
}
Op& op() const override;
protected:
size_t m_axis;
};
public:
std::shared_ptr<BroadcastCall> operator()(const std::shared_ptr<Node>& tensor, size_t axis)
std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis)
{
return std::make_shared<BroadcastCall>(tensor, axis);
}
......@@ -62,15 +63,16 @@ namespace ngraph
friend class Dot;
public:
DotCall(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
DotCall(const std::shared_ptr<Node>& arg0, const Node::ptr& arg1)
: Call({arg0, arg1})
{
}
Op& op() const override;
};
public:
std::shared_ptr<DotCall> operator()(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
std::shared_ptr<DotCall> operator()(const Node::ptr& arg0, const Node::ptr& arg1)
{
return std::make_shared<DotCall>(arg0, arg1);
}
......
......@@ -14,56 +14,30 @@
#pragma once
#include <memory>
#include <vector>
#include "element_type.hpp"
namespace ngraph
{
/**
** Holds the shape of a tensor view.
**/
class Shape
{
public:
/**
** \param sizes A sequence of sizes.
**/
Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes)
{
}
protected:
std::vector<size_t> m_sizes;
};
// ValueType is
// TensorViewType
// | TupleType(ValueType[])
class ValueType
{
};
class TensorViewType : public ValueType
{
public:
TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{
}
protected:
TensorViewType(const TensorViewType&) = delete;
const ElementType& m_element_type;
Shape m_shape;
};
class TupleType : public ValueType
{
public:
TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types)
: m_element_types(element_types)
{
}
/**
** Conversion to a vector of sizes.
**/
operator const std::vector<size_t>&() const { return m_sizes; }
protected:
std::vector<std::shared_ptr<ValueType>> m_element_types;
std::vector<size_t> m_sizes;
};
}
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/element_type.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
/**
** ValueType is
** TensorViewType
** | TupleType(ValueType[])
**/
class ValueType
{
public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
};
/**
** Describes a tensor view; an element type and a shape.
**/
class TensorViewType : public ValueType
{
public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<TensorViewType>;
/**
** /param element_type The type of the tensor elements.
** /param shape The shape of the tensor.
**/
TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{
}
protected:
const ElementType& m_element_type;
Shape m_shape;
};
/**
** Describes a tuple of values; a vector of types
**/
class TupleType : public ValueType
{
public:
/**
** The preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
/**
** Construct empty tuple and add value types later.
**/
TupleType() {}
/**
** /param element_types A vector of types for the tuple elements
**/
TupleType(const std::vector<ValueType::ptr>& element_types)
: m_element_types(element_types)
{
}
const std::vector<ValueType::ptr> element_types() const { return m_element_types; }
std::vector<ValueType::ptr> element_types() { return m_element_types; }
protected:
std::vector<ValueType::ptr> m_element_types;
};
/**
** Mixin for objects with type information
**/
class TypedValueMixin
{
public:
TypedValueMixin(const ValueType::ptr& type = 0)
: m_type(type)
{
}
/**
** Set the type
** /param type The new type
**/
void type(const ValueType::ptr& type) { m_type = type; }
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void type(const ElementType& element_type, const Shape& shape)
{
m_type = TensorViewType::ptr::make_shared(element_type, shape);
}
/**
** The type associated with this value.
**/
ValueType::ptr type() { return m_type; }
/**
** The type associated with this value.
**/
const ValueType::ptr type() const { return m_type; }
protected:
ValueType::ptr m_type;
};
}
\ No newline at end of file
......@@ -12,7 +12,23 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "values/function.hpp"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
Parameter::Parameter(Function& function, size_t index)
: Node({})
, m_function(function)
, m_index(index)
{
}
Function::Function(size_t n_parameters)
: m_parameters(n_parameters)
{
for (int i = 0; i < n_parameters; i++)
{
m_parameters[i] = Parameter::ptr::make_shared(*this, i);
}
}
......@@ -12,9 +12,20 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "values/op.hpp"
#include "ngraph/ngraph.hpp"
using namespace ngraph;
Broadcast ngraph::op::broadcast{};
Op& ngraph::Broadcast::BroadcastCall::op() const
{
return op::broadcast;
}
Dot ngraph::op::dot{};
Op& ngraph::Dot::DotCall::op() const
{
return op::dot;
}
......@@ -15,7 +15,7 @@
#include <cassert>
#include <cmath>
#include "element_type.hpp"
#include "ngraph/element_type.hpp"
const ngraph::ElementType element_type_float = ngraph::ElementType(32, true, true, "float");
const ngraph::ElementType element_type_int8_t = ngraph::ElementType(8, false, true, "int8_t");
......
......@@ -14,31 +14,31 @@
#include "gtest/gtest.h"
#include "values/type.hpp"
#include "values/function.hpp"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
TEST(graph, build_simple)
{
// // Function with 4 parameters
// auto cluster_0 = make_shared<Function>(4);
// cluster_0->result()->type(element_type_float, Shape {32, 3});
// cluster_0->parameter(0)->type(element_type_float, Shape {Shape {7, 3}});
// cluster_0->parameter(1)->type(element_type_float, Shape {Shape {3}});
// cluster_0->parameter(2)->type(element_type_float, Shape {Shape {32, 7}});
// cluster_0->parameter(3)->type(element_type_float, Shape {Shape {32, 7}});
// auto arg3 = cluster_0->parameter(3);
// // call broadcast op on arg3, broadcasting on axis 1.
// auto broadcast_1 = op::broadcast(arg3, 1);
// auto arg2 = cluster_0->parameter(2);
// auto arg0 = cluster_0->parameter(0);
// // call dot op
// auto dot = op::dot(arg2, arg0);
// ASSERT_EQ(dot->dependents()[0], arg2);
// // Function returns tuple of dot and broadcast_1.
// cluster_0->result()->value(dot);
// Function with 4 parameters
auto cluster_0 = make_shared<Function>(4);
cluster_0->result()->type(element_type_float, {32, 3});
cluster_0->parameter(0)->type(element_type_float, {7, 3});
cluster_0->parameter(1)->type(element_type_float, {3});
cluster_0->parameter(2)->type(element_type_float, {32, 7});
cluster_0->parameter(3)->type(element_type_float, {32, 7});
auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1.
auto broadcast_1 = op::broadcast(arg3, 1);
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
// call dot op
auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->dependents()[0], arg2);
ASSERT_EQ(dot->dependents()[1], arg0);
// Function returns tuple of dot and broadcast_1.
cluster_0->result()->value(dot);
// ASSERT_EQ(cluster_0->result()->value(), dot);
ASSERT_EQ(cluster_0->result()->value(), dot);
}
......@@ -18,6 +18,6 @@
#include "gtest/gtest.h"
#include "element_type.hpp"
#include "ngraph/element_type.hpp"
using namespace ngraph;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment