Commit 0c06b371 authored by Scott Cyphers's avatar Scott Cyphers

Separate Parameter from Function

parent 2c30e819
......@@ -20,6 +20,7 @@ set (SRC
log.cpp
ops/function.cpp
ops/op.cpp
ops/parameter.cpp
types/element_type.cpp
)
......
......@@ -16,47 +16,11 @@
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/type.hpp"
namespace ngraph
{
class Function;
/**
** One parameter of a function. Within the function's graph
** the parameter is a node that represents the argument in a call.
**/
class Parameter : public Node
{
public:
using ptr = std::shared_ptr<Parameter>;
Parameter(Function& function, size_t index);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
protected:
Function& m_function;
size_t m_index;
};
/**
** The result of a function. The ndoe addociated with the result
** supplies the return value when the function is called.
**/
class Result : public TypedValueMixin
{
public:
using ptr = std::shared_ptr<Result>;
Node::ptr value() const { return m_value; }
void value(const Node::ptr& value) { m_value = value; }
protected:
Node::ptr m_value;
};
/**
** A user-defined function.
......@@ -64,17 +28,23 @@ namespace ngraph
class Function
{
public:
Function(size_t n_parameters);
Function(const Node::ptr& result, const std::vector<std::shared_ptr<Parameter>>& parameters);
Result* result() { return &m_result; }
Node::ptr result() { return m_result; }
Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; }
protected:
std::vector<Parameter::ptr> m_parameters;
Result m_result;
Node::ptr m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
std::string m_name;
};
namespace op
{
std::shared_ptr<Function> function(const Node::ptr& result, const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function> function(const Node::ptr& result, const std::vector<std::shared_ptr<Parameter>>& parameters);
}
}
......@@ -23,5 +23,6 @@
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
......@@ -17,6 +17,7 @@
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/type.hpp"
namespace ngraph
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/type.hpp"
namespace ngraph
{
class Function;
/**
** One parameter of a function. Within the function's graph
** the parameter is a node that represents the argument in a call.
**/
class Parameter : public Node
{
friend class Function;
protected:
void assign_function(Function* function, size_t index);
public:
Parameter(const ValueType::ptr& value_type);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
protected:
Function* m_function;
size_t m_index;
};
namespace op
{
std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type=nullptr);
std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type, const Shape& shape);
}
}
......@@ -17,27 +17,24 @@
using namespace std;
using namespace ngraph;
Parameter::Parameter(Function& function, size_t index)
: Node({})
, m_function(function)
, m_index(index)
Function::Function(const Node::ptr& result, const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters)
: m_result(result)
, m_parameters(parameters)
, m_name("Function")
{
size_t i = 0;
for (auto parameter : parameters)
{
parameter->assign_function(this, i++);
}
}
void Parameter::propagate_types()
shared_ptr<Function> ngraph::op::function(const Node::ptr& result, const initializer_list<shared_ptr<Parameter>>& parameters)
{
if (m_type == nullptr)
{
throw ngraph_error{"Unitialized parameter"};
}
return make_shared<Function>(result, parameters);
}
Function::Function(size_t n_parameters)
: m_parameters(n_parameters)
, m_name("Function")
shared_ptr<Function> ngraph::op::function(const Node::ptr& result, const vector<shared_ptr<Parameter>>& parameters)
{
for (int i = 0; i < n_parameters; i++)
{
m_parameters[i] = std::make_shared<Parameter>(*this, i);
}
return make_shared<Function>(result, parameters);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
Parameter::Parameter(const ValueType::ptr& value_type)
: Node({}, value_type)
, m_function(nullptr)
, m_index(0)
{
}
void Parameter::assign_function(Function* function, size_t index)
{
if (nullptr != m_function){
throw ngraph_error("Re-assigning function to a parameter.");
}
m_function = function;
m_index = index;
}
void Parameter::propagate_types()
{
}
shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type)
{
return make_shared<Parameter>(value_type);
}
shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_type, const Shape& shape)
{
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment