Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
0c06b371
Commit
0c06b371
authored
Aug 29, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Separate Parameter from Function
parent
2c30e819
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
127 additions
and
57 deletions
+127
-57
CMakeLists.txt
src/CMakeLists.txt
+1
-0
function.hpp
src/ngraph/function.hpp
+11
-41
ngraph.hpp
src/ngraph/ngraph.hpp
+1
-0
op.hpp
src/ngraph/op.hpp
+1
-0
parameter.hpp
src/ngraph/parameter.hpp
+52
-0
function.cpp
src/ops/function.cpp
+13
-16
parameter.cpp
src/ops/parameter.cpp
+48
-0
No files found.
src/CMakeLists.txt
View file @
0c06b371
...
...
@@ -20,6 +20,7 @@ set (SRC
log.cpp
ops/function.cpp
ops/op.cpp
ops/parameter.cpp
types/element_type.cpp
)
...
...
src/ngraph/function.hpp
View file @
0c06b371
...
...
@@ -16,47 +16,11 @@
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/type.hpp"
namespace
ngraph
{
class
Function
;
/**
** One parameter of a function. Within the function's graph
** the parameter is a node that represents the argument in a call.
**/
class
Parameter
:
public
Node
{
public
:
using
ptr
=
std
::
shared_ptr
<
Parameter
>
;
Parameter
(
Function
&
function
,
size_t
index
);
std
::
string
description
()
const
override
{
return
"Parameter"
;
}
virtual
void
propagate_types
()
override
;
protected
:
Function
&
m_function
;
size_t
m_index
;
};
/**
** The result of a function. The ndoe addociated with the result
** supplies the return value when the function is called.
**/
class
Result
:
public
TypedValueMixin
{
public
:
using
ptr
=
std
::
shared_ptr
<
Result
>
;
Node
::
ptr
value
()
const
{
return
m_value
;
}
void
value
(
const
Node
::
ptr
&
value
)
{
m_value
=
value
;
}
protected
:
Node
::
ptr
m_value
;
};
/**
** A user-defined function.
...
...
@@ -64,17 +28,23 @@ namespace ngraph
class
Function
{
public
:
Function
(
size_t
n_
parameters
);
Function
(
const
Node
::
ptr
&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
Result
*
result
()
{
return
&
m_result
;
}
Node
::
ptr
result
()
{
return
m_result
;
}
Parameter
::
ptr
parameter
(
size_t
i
)
{
return
m_parameters
[
i
];
}
std
::
string
name
()
const
{
return
m_name
;
}
protected
:
std
::
vector
<
Parameter
::
ptr
>
m_parameters
;
Result
m_result
;
Node
::
ptr
m_result
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>
m_parameters
;
std
::
string
m_name
;
};
namespace
op
{
std
::
shared_ptr
<
Function
>
function
(
const
Node
::
ptr
&
result
,
const
std
::
initializer_list
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
std
::
shared_ptr
<
Function
>
function
(
const
Node
::
ptr
&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
}
}
src/ngraph/ngraph.hpp
View file @
0c06b371
...
...
@@ -23,5 +23,6 @@
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
src/ngraph/op.hpp
View file @
0c06b371
...
...
@@ -17,6 +17,7 @@
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/type.hpp"
namespace
ngraph
...
...
src/ngraph/parameter.hpp
0 → 100644
View file @
0c06b371
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/type.hpp"
namespace
ngraph
{
class
Function
;
/**
** One parameter of a function. Within the function's graph
** the parameter is a node that represents the argument in a call.
**/
class
Parameter
:
public
Node
{
friend
class
Function
;
protected
:
void
assign_function
(
Function
*
function
,
size_t
index
);
public
:
Parameter
(
const
ValueType
::
ptr
&
value_type
);
std
::
string
description
()
const
override
{
return
"Parameter"
;
}
virtual
void
propagate_types
()
override
;
protected
:
Function
*
m_function
;
size_t
m_index
;
};
namespace
op
{
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
ValueType
::
ptr
&
value_type
=
nullptr
);
std
::
shared_ptr
<
ngraph
::
Parameter
>
parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
);
}
}
src/ops/function.cpp
View file @
0c06b371
...
...
@@ -17,27 +17,24 @@
using
namespace
std
;
using
namespace
ngraph
;
Parameter
::
Parameter
(
Function
&
function
,
size_t
index
)
:
Node
({}
)
,
m_
function
(
function
)
,
m_
index
(
index
)
Function
::
Function
(
const
Node
::
ptr
&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Parameter
>>&
parameters
)
:
m_result
(
result
)
,
m_
parameters
(
parameters
)
,
m_
name
(
"Function"
)
{
size_t
i
=
0
;
for
(
auto
parameter
:
parameters
)
{
parameter
->
assign_function
(
this
,
i
++
);
}
}
void
Parameter
::
propagate_types
(
)
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
Node
::
ptr
&
result
,
const
initializer_list
<
shared_ptr
<
Parameter
>>&
parameters
)
{
if
(
m_type
==
nullptr
)
{
throw
ngraph_error
{
"Unitialized parameter"
};
}
return
make_shared
<
Function
>
(
result
,
parameters
);
}
Function
::
Function
(
size_t
n_parameters
)
:
m_parameters
(
n_parameters
)
,
m_name
(
"Function"
)
shared_ptr
<
Function
>
ngraph
::
op
::
function
(
const
Node
::
ptr
&
result
,
const
vector
<
shared_ptr
<
Parameter
>>&
parameters
)
{
for
(
int
i
=
0
;
i
<
n_parameters
;
i
++
)
{
m_parameters
[
i
]
=
std
::
make_shared
<
Parameter
>
(
*
this
,
i
);
}
return
make_shared
<
Function
>
(
result
,
parameters
);
}
src/ops/parameter.cpp
0 → 100644
View file @
0c06b371
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
Parameter
::
Parameter
(
const
ValueType
::
ptr
&
value_type
)
:
Node
({},
value_type
)
,
m_function
(
nullptr
)
,
m_index
(
0
)
{
}
void
Parameter
::
assign_function
(
Function
*
function
,
size_t
index
)
{
if
(
nullptr
!=
m_function
){
throw
ngraph_error
(
"Re-assigning function to a parameter."
);
}
m_function
=
function
;
m_index
=
index
;
}
void
Parameter
::
propagate_types
()
{
}
shared_ptr
<
Parameter
>
ngraph
::
op
::
parameter
(
const
ValueType
::
ptr
&
value_type
)
{
return
make_shared
<
Parameter
>
(
value_type
);
}
shared_ptr
<
Parameter
>
ngraph
::
op
::
parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
)
{
return
make_shared
<
Parameter
>
(
make_shared
<
TensorViewType
>
(
element_type
,
shape
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment