Commit 2bb6d51d authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #53 from NervanaSystems/cyphers/view

Cyphers/view
parents 853b436a 520d9d5d
...@@ -15,19 +15,12 @@ get_filename_component( NGRAPH_INCLUDE_DIR . ABSOLUTE) ...@@ -15,19 +15,12 @@ get_filename_component( NGRAPH_INCLUDE_DIR . ABSOLUTE)
set(NGRAPH_INCLUDE_DIR "${NGRAPH_INCLUDE_DIR}" PARENT_SCOPE) set(NGRAPH_INCLUDE_DIR "${NGRAPH_INCLUDE_DIR}" PARENT_SCOPE)
set (SRC set (SRC
element_type.cpp
names.cpp
strides.cpp
tree.cpp tree.cpp
util.cpp util.cpp
log.cpp log.cpp
ngraph.cpp ops/function.cpp
ops/op.cpp
transformers/axes.cpp types/element_type.cpp
transformers/exop.cpp
transformers/mock_transformer.cpp
transformers/ndarray.cpp
transformers/op_graph.cpp
) )
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled # NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
...@@ -53,20 +46,12 @@ endif() ...@@ -53,20 +46,12 @@ endif()
set(DEPLOY_SRC_HEADERS set(DEPLOY_SRC_HEADERS
element_type.hpp element_type.hpp
names.hpp
strides.hpp
tree.hpp tree.hpp
util.hpp util.hpp
uuid.hpp uuid.hpp
) )
set(DEPLOY_SRC_TRANSFORMERS_HEADERS set(DEPLOY_SRC_TRANSFORMERS_HEADERS
transformers/axes.hpp
transformers/exop.hpp
transformers/mock.hpp
transformers/mock_transformer.hpp
transformers/ndarray.hpp
transformers/op_graph.hpp
) )
install(TARGETS ngraph DESTINATION lib) install(TARGETS ngraph DESTINATION lib)
......
...@@ -19,16 +19,17 @@ ...@@ -19,16 +19,17 @@
class NGraph class NGraph
{ {
public: public:
void add_params(const std::vector<std::string>& paramList); void add_params(const std::vector<std::string>& paramList);
const std::vector<std::string>& get_params() const; const std::vector<std::string>& get_params() const;
std::string get_name() const { return "NGraph Implementation Object"; } std::string get_name() const { return "NGraph Implementation Object"; }
private: private:
std::vector<std::string> m_params; std::vector<std::string> m_params;
}; };
// Factory methods // Factory methods
extern "C" NGraph* create_ngraph_object(); extern "C" NGraph* create_ngraph_object();
extern "C" void destroy_ngraph_object(NGraph* pObj); extern "C" void destroy_ngraph_object(NGraph* pObj);
// FUnction pointers to the factory methods // FUnction pointers to the factory methods
typedef NGraph* (*CreateNGraphObjPfn)(); typedef NGraph* (*CreateNGraphObjPfn)();
......
...@@ -43,10 +43,10 @@ public: ...@@ -43,10 +43,10 @@ public:
private: private:
static std::map<std::string, ElementType> m_element_list; static std::map<std::string, ElementType> m_element_list;
size_t m_bitwidth; size_t m_bitwidth;
bool m_is_float; bool m_is_float;
bool m_is_signed; bool m_is_signed;
const std::string m_cname; const std::string m_cname;
}; };
extern const ngraph::ElementType element_type_float; extern const ngraph::ElementType element_type_float;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/type.hpp"
namespace ngraph
{
class Function;
/**
** One parameter of a function. Within the function's graph
** the parameter is a node that represents the argument in a call.
**/
class Parameter : public Node
{
public:
using ptr = std::shared_ptr<Parameter>;
Parameter(Function& function, size_t index);
protected:
Function& m_function;
size_t m_index;
};
/**
** The result of a function. The ndoe addociated with the result
** supplies the return value when the function is called.
**/
class Result : public TypedValueMixin
{
public:
using ptr = std::shared_ptr<Result>;
Node::ptr value() const { return m_value; }
void value(const Node::ptr& value) { m_value = value; }
protected:
Node::ptr m_value;
};
/**
** A user-defined function.
**/
class Function : public Op
{
public:
Function(size_t n_parameters);
Result* result() { return &m_result; }
Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
protected:
std::vector<Parameter::ptr> m_parameters;
Result m_result;
};
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
//
// The public API for ngraph++
//
#pragma once
#include "ngraph/element_type.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
#include "ngraph/type.hpp"
namespace ngraph
{
class Op;
/**
** Nodes are the backbone of the graph of Value dataflow. Every node has
** zero or more nodes as arguments and one value, which is either a tensor
** view or a (possibly empty) tuple of values.
**/
class Node : public TypedValueMixin
{
public:
using ptr = std::shared_ptr<Node>;
Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = nullptr)
: TypedValueMixin(type)
, m_arguments(arguments)
{
}
const std::vector<Node::ptr> arguments() const { return m_arguments; }
std::vector<Node::ptr> arguments() { return m_arguments; }
protected:
std::vector<Node::ptr> m_arguments;
};
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/type.hpp"
namespace ngraph
{
class Op;
/**
** Call nodes are nodes whose value is the result of some operation, the op,
** applied to its arguments. We use the op as a callable to construct the
** call nodes.
**/
class Call : public Node
{
public:
std::shared_ptr<Op> op() const { return m_op; }
Call(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr)
, m_op(op)
{
}
protected:
std::shared_ptr<Op> m_op;
};
/**
** The Op class provides the behavior for a Call.
**/
class Op
{
};
class Broadcast : public Op, public std::enable_shared_from_this<Broadcast>
{
protected:
class BroadcastCall : public Call
{
friend class Broadcast;
public:
BroadcastCall(const std::shared_ptr<Op>& op, const Node::ptr& arg, size_t axis)
: Call(op, {arg})
, m_axis(axis)
{
}
protected:
size_t m_axis;
};
public:
std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis)
{
return std::make_shared<BroadcastCall>(shared_from_this(), tensor, axis);
}
};
namespace op
{
extern decltype(*std::shared_ptr<Broadcast>()) broadcast;
}
class Dot : public Op, public std::enable_shared_from_this<Dot>
{
public:
Call::ptr operator()(const Node::ptr& arg0, const Node::ptr& arg1)
{
return std::make_shared<Call>(shared_from_this(), std::vector<Node::ptr>{arg0, arg1});
}
};
namespace op
{
extern decltype(*std::shared_ptr<Dot>()) dot;
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
namespace ngraph
{
/**
** Holds the shape of a tensor view.
**/
class Shape
{
public:
/**
** \param sizes A sequence of sizes.
**/
Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes)
{
}
/**
** Conversion to a vector of sizes.
**/
operator const std::vector<size_t>&() const { return m_sizes; }
protected:
std::vector<size_t> m_sizes;
};
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/element_type.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
/**
** ValueType is
** TensorViewType
** | TupleType(ValueType[])
**/
class ValueType
{
public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
};
/**
** Describes a tensor view; an element type and a shape.
**/
class TensorViewType : public ValueType
{
public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<TensorViewType>;
/**
** /param element_type The type of the tensor elements.
** /param shape The shape of the tensor.
**/
TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{
}
protected:
const ElementType& m_element_type;
Shape m_shape;
};
/**
** Describes a tuple of values; a vector of types
**/
class TupleType : public ValueType
{
public:
/**
** The preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
/**
** Construct empty tuple and add value types later.
**/
TupleType() {}
/**
** /param element_types A vector of types for the tuple elements
**/
TupleType(const std::vector<ValueType::ptr>& element_types)
: m_element_types(element_types)
{
}
const std::vector<ValueType::ptr> element_types() const { return m_element_types; }
std::vector<ValueType::ptr> element_types() { return m_element_types; }
protected:
std::vector<ValueType::ptr> m_element_types;
};
/**
** Mixin for objects with type information
**/
class TypedValueMixin
{
public:
TypedValueMixin(const ValueType::ptr& type = nullptr)
: m_type(type)
{
}
/**
** Set the type
** /param type The new type
**/
void type(const ValueType::ptr& type) { m_type = type; }
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void type(const ElementType& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
/**
** The type associated with this value.
**/
ValueType::ptr type() { return m_type; }
/**
** The type associated with this value.
**/
const ValueType::ptr type() const { return m_type; }
protected:
ValueType::ptr m_type;
};
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
Parameter::Parameter(Function& function, size_t index)
: Node({})
, m_function(function)
, m_index(index)
{
}
Function::Function(size_t n_parameters)
: m_parameters(n_parameters)
{
for (int i = 0; i < n_parameters; i++)
{
m_parameters[i] = std::make_shared<Parameter>(*this, i);
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace ngraph;
decltype(*std::shared_ptr<Broadcast>()) ngraph::op::broadcast = *std::make_shared<Broadcast>();
decltype(*std::shared_ptr<Dot>()) ngraph::op::dot = *std::make_shared<Dot>();
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include "element_type.hpp" #include "ngraph/element_type.hpp"
const ngraph::ElementType element_type_float = ngraph::ElementType(32, true, true, "float"); const ngraph::ElementType element_type_float = ngraph::ElementType(32, true, true, "float");
const ngraph::ElementType element_type_int8_t = ngraph::ElementType(8, false, true, "int8_t"); const ngraph::ElementType element_type_int8_t = ngraph::ElementType(8, false, true, "int8_t");
......
...@@ -22,16 +22,11 @@ include_directories( ...@@ -22,16 +22,11 @@ include_directories(
set (SRC set (SRC
main.cpp main.cpp
build_graph.cpp
util.cpp util.cpp
tensor.cpp tensor.cpp
exop.cpp
axes.cpp
element_type.cpp element_type.cpp
op_graph.cpp
uuid.cpp uuid.cpp
names.cpp
strides.cpp
ngraph.cpp
) )
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
TEST(graph, build_simple)
{
// Function with 4 parameters
auto cluster_0 = make_shared<Function>(4);
cluster_0->result()->type(element_type_float, {32, 3});
cluster_0->parameter(0)->type(element_type_float, {7, 3});
cluster_0->parameter(1)->type(element_type_float, {3});
cluster_0->parameter(2)->type(element_type_float, {32, 7});
cluster_0->parameter(3)->type(element_type_float, {32, 7});
auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1.
auto broadcast_1 = op::broadcast(arg3, 1);
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
// call dot op
auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->dependents()[0], arg2);
ASSERT_EQ(dot->dependents()[1], arg0);
// Function returns tuple of dot and broadcast_1.
cluster_0->result()->value(dot);
ASSERT_EQ(cluster_0->result()->value(), dot);
}
...@@ -18,6 +18,6 @@ ...@@ -18,6 +18,6 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "element_type.hpp" #include "ngraph/element_type.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
using namespace std; using namespace std;
extern "C" int main(int argc, char** argv) int main(int argc, char** argv)
{ {
const char* exclude = "--gtest_filter=-benchmark.*"; const char* exclude = "--gtest_filter=-benchmark.*";
vector<char*> argv_vector; vector<char*> argv_vector;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment