Commit 8827be11 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #68 from NervanaSystems/cyphers/function

Cyphers/function
parents 2c30e819 304f1219
...@@ -20,6 +20,7 @@ set (SRC ...@@ -20,6 +20,7 @@ set (SRC
log.cpp log.cpp
ops/function.cpp ops/function.cpp
ops/op.cpp ops/op.cpp
ops/parameter.cpp
types/element_type.cpp types/element_type.cpp
) )
......
...@@ -16,65 +16,39 @@ ...@@ -16,65 +16,39 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op.hpp" #include "ngraph/op.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
namespace ngraph namespace ngraph
{ {
class Function;
/**
** One parameter of a function. Within the function's graph
** the parameter is a node that represents the argument in a call.
**/
class Parameter : public Node
{
public:
using ptr = std::shared_ptr<Parameter>;
Parameter(Function& function, size_t index);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
protected:
Function& m_function;
size_t m_index;
};
/**
** The result of a function. The ndoe addociated with the result
** supplies the return value when the function is called.
**/
class Result : public TypedValueMixin
{
public:
using ptr = std::shared_ptr<Result>;
Node::ptr value() const { return m_value; }
void value(const Node::ptr& value) { m_value = value; }
protected:
Node::ptr m_value;
};
/** /**
** A user-defined function. ** A user-defined function.
**/ **/
class Function class Function
{ {
public: public:
Function(size_t n_parameters); Function(const Node::ptr& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
Result* result() { return &m_result; } Node::ptr result() { return m_result; }
Parameter::ptr parameter(size_t i) { return m_parameters[i]; } Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; } std::string name() const { return m_name; }
protected: protected:
std::vector<Parameter::ptr> m_parameters; Node::ptr m_result;
Result m_result; std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
std::string m_name; std::string m_name;
}; };
namespace op
{
std::shared_ptr<Function>
function(const Node::ptr& result,
const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function>
function(const Node::ptr& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
}
} }
...@@ -23,5 +23,6 @@ ...@@ -23,5 +23,6 @@
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op.hpp" #include "ngraph/op.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
...@@ -69,7 +69,10 @@ namespace ngraph ...@@ -69,7 +69,10 @@ namespace ngraph
** will be used by the pattern matcher when comparing a pattern ** will be used by the pattern matcher when comparing a pattern
** graph against the graph. ** graph against the graph.
**/ **/
bool is_same_op_type(const Node::ptr& node) const { return typeid(*this) == typeid(*node.get()); } bool is_same_op_type(const Node::ptr& node) const
{
return typeid(*this) == typeid(*node.get());
}
protected: protected:
std::vector<Node::ptr> m_arguments; std::vector<Node::ptr> m_arguments;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
namespace ngraph namespace ngraph
...@@ -72,7 +73,6 @@ namespace ngraph ...@@ -72,7 +73,6 @@ namespace ngraph
class Op : public Node class Op : public Node
{ {
public: public:
Op(const std::vector<Node::ptr>& arguments) Op(const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr) : Node(arguments, nullptr)
{ {
...@@ -85,8 +85,8 @@ namespace ngraph ...@@ -85,8 +85,8 @@ namespace ngraph
**/ **/
class FunctionOp : public Op class FunctionOp : public Op
{ {
virtual std::string description() const override { return "FunctionOp"; } virtual std::string description() const override { return "FunctionOp"; }
protected: protected:
Node::ptr m_function; Node::ptr m_function;
}; };
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/type.hpp"
namespace ngraph
{
class Function;
/**
** Parameters are nodes that represent the arguments that will be passed to user-defined functions.
** Function creation requires a sequence of parameters.
** Basic graph operations do not need parameters attached to a function.
**/
class Parameter : public Node
{
friend class Function;
protected:
// Called by the Function constructor to associate this parameter with the function.
// It is an error to try to associate a parameter with more than one function.
void assign_function(Function* function, size_t index);
public:
Parameter(const ValueType::ptr& value_type);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
protected:
Function* m_function;
size_t m_index;
};
namespace op
{
/// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type = nullptr);
/// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type,
const Shape& shape);
}
}
...@@ -17,27 +17,27 @@ ...@@ -17,27 +17,27 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Parameter::Parameter(Function& function, size_t index) Function::Function(const Node::ptr& result,
: Node({}) const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters)
, m_function(function) : m_result(result)
, m_index(index) , m_parameters(parameters)
, m_name("Function")
{ {
size_t i = 0;
for (auto parameter : parameters)
{
parameter->assign_function(this, i++);
}
} }
void Parameter::propagate_types() shared_ptr<Function> ngraph::op::function(const Node::ptr& result,
const initializer_list<shared_ptr<Parameter>>& parameters)
{ {
if (m_type == nullptr) return make_shared<Function>(result, parameters);
{
throw ngraph_error{"Unitialized parameter"};
}
} }
Function::Function(size_t n_parameters) shared_ptr<Function> ngraph::op::function(const Node::ptr& result,
: m_parameters(n_parameters) const vector<shared_ptr<Parameter>>& parameters)
, m_name("Function")
{ {
for (int i = 0; i < n_parameters; i++) return make_shared<Function>(result, parameters);
{
m_parameters[i] = std::make_shared<Parameter>(*this, i);
}
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
Parameter::Parameter(const ValueType::ptr& value_type)
: Node({}, value_type)
, m_function(nullptr)
, m_index(0)
{
}
void Parameter::assign_function(Function* function, size_t index)
{
if (nullptr != m_function)
{
throw ngraph_error("Re-assigning function to a parameter.");
}
m_function = function;
m_index = index;
}
void Parameter::propagate_types() {}
shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type)
{
return make_shared<Parameter>(value_type);
}
shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_type,
const Shape& shape)
{
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
...@@ -22,23 +22,18 @@ using namespace ngraph; ...@@ -22,23 +22,18 @@ using namespace ngraph;
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto cluster_0 = make_shared<Function>(4); auto arg0 = op::parameter(element::float32_t, {7, 3});
cluster_0->result()->type(element::float32_t, {32, 3}); auto arg1 = op::parameter(element::float32_t, {3});
cluster_0->parameter(0)->type(element::float32_t, {7, 3}); auto arg2 = op::parameter(element::float32_t, {32, 7});
cluster_0->parameter(1)->type(element::float32_t, {3}); auto arg3 = op::parameter(element::float32_t, {32, 7});
cluster_0->parameter(2)->type(element::float32_t, {32, 7});
cluster_0->parameter(3)->type(element::float32_t, {32, 7});
auto arg3 = cluster_0->parameter(3);
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0}); auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2); ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0); ASSERT_EQ(dot->arguments()[1], arg0);
// Function returns tuple of dot and broadcast_1.
cluster_0->result()->value(dot);
ASSERT_EQ(cluster_0->result()->value(), dot); auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->result(), dot);
} }
// Check upcasting from ValueType. // Check upcasting from ValueType.
...@@ -62,20 +57,14 @@ TEST(build_graph, as_type) ...@@ -62,20 +57,14 @@ TEST(build_graph, as_type)
// Check node comparisons // Check node comparisons
TEST(build_graph, node_comparison) TEST(build_graph, node_comparison)
{ {
auto fun = make_shared<Function>(3); auto arg0 = op::parameter(element::float32_t, {32, 3});
fun->parameter(0)->type(element::float32_t, {32, 3}); auto arg1 = op::parameter(element::float32_t, {3});
fun->parameter(1)->type(element::float32_t, {3}); auto arg2 = op::parameter(element::float32_t, {32});
fun->parameter(2)->type(element::float32_t, {32});
auto arg0 = fun->parameter(0);
auto arg1 = fun->parameter(1);
auto arg2 = fun->parameter(2);
auto dot = op::dot(arg0, arg1); auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2); auto add = op::add(dot, arg2);
auto pattern = make_shared<Function>(1); auto parg = op::parameter(element::float32_t, {});
pattern->parameter(0)->type(element::float32_t, {});
auto parg = pattern->parameter(0);
auto pattern_dot = op::dot(parg, parg); auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
...@@ -84,8 +73,4 @@ TEST(build_graph, node_comparison) ...@@ -84,8 +73,4 @@ TEST(build_graph, node_comparison)
} }
// Check argument inverses // Check argument inverses
TEST(build_graph, arg_inverse) TEST(build_graph, arg_inverse) {}
{
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment