Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
19f16bc1
Commit
19f16bc1
authored
Aug 22, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Finish basic graph building.
parent
d37359aa
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
150 additions
and
63 deletions
+150
-63
CMakeLists.txt
src/CMakeLists.txt
+1
-0
descriptor.hpp
src/values/descriptor.hpp
+1
-2
function.hpp
src/values/function.hpp
+25
-25
node.hpp
src/values/node.hpp
+23
-17
op.cpp
src/values/op.cpp
+21
-0
op.hpp
src/values/op.hpp
+57
-4
type.hpp
src/values/type.hpp
+13
-11
build_graph.cpp
test/build_graph.cpp
+9
-4
No files found.
src/CMakeLists.txt
View file @
19f16bc1
...
...
@@ -29,6 +29,7 @@ set (SRC
transformers/op_graph.cpp
values/function.cpp
values/op.cpp
)
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
...
...
src/values/descriptor.hpp
View file @
19f16bc1
...
...
@@ -20,7 +20,6 @@
#include "values/type.hpp"
namespace
ngraph
namespace
ngraph
{
}
src/values/function.hpp
View file @
19f16bc1
...
...
@@ -18,61 +18,60 @@
#include "values/op.hpp"
#include "values/type.hpp"
namespace
ngraph
namespace
ngraph
{
class
Function
;
class
Parameter
:
public
Node
{
public
:
Parameter
(
Function
&
function
,
size_t
index
,
const
std
::
shared_ptr
<
ValueType
>&
type
)
:
Node
(
type
)
,
m_function
(
function
)
,
m_index
(
index
)
{}
:
Node
({},
type
)
,
m_function
(
function
)
,
m_index
(
index
)
{
}
protected
:
Function
&
m_function
;
size_t
m_index
;
size_t
m_index
;
};
class
Result
{
class
Result
{
public
:
void
type
(
const
std
::
shared_ptr
<
ValueType
>&
t
){
m_type
=
t
;
}
void
type
(
const
std
::
shared_ptr
<
ValueType
>&
t
)
{
m_type
=
t
;
}
void
type
(
const
ElementType
&
element_type
,
const
Shape
&
shape
){
void
type
(
const
ElementType
&
element_type
,
const
Shape
&
shape
)
{
m_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
}
std
::
shared_ptr
<
ValueType
>
type
()
const
{
return
m_type
;
}
std
::
shared_ptr
<
ValueType
>
type
()
const
{
return
m_type
;
}
std
::
shared_ptr
<
Node
>
value
()
const
{
return
m_value
;
}
void
value
(
const
std
::
shared_ptr
<
Node
>&
value
)
{
m_value
=
value
;
}
protected
:
std
::
shared_ptr
<
ValueType
>
m_type
;
std
::
shared_ptr
<
Node
>
m_value
;
};
class
Function
{
public
:
Function
(
size_t
n_parameters
)
:
m_parameters
(
n_parameters
)
{}
Result
*
result
(){
return
&
m_result
;
:
m_parameters
(
n_parameters
)
{
}
std
::
shared_ptr
<
Parameter
>
parameter
(
size_t
i
){
return
m_parameters
[
i
];
}
Result
*
result
()
{
return
&
m_result
;
}
std
::
shared_ptr
<
Parameter
>
parameter
(
size_t
i
)
{
return
m_parameters
[
i
];
}
protected
:
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>
m_parameters
;
Result
m_result
;
Result
m_result
;
};
}
// end namespace ngraph
\ No newline at end of file
src/values/node.hpp
View file @
19f16bc1
...
...
@@ -20,34 +20,39 @@
namespace
ngraph
{
class
Node
{
public
:
Node
(
std
::
shared_ptr
<
ValueType
>
type
=
0
)
:
m_type
(
type
)
{}
virtual
~
Node
(){}
virtual
std
::
vector
<
std
::
shared_ptr
<
Node
>>
dependents
()
{
return
m_parameters
;
Node
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
,
std
::
shared_ptr
<
ValueType
>
type
=
0
)
:
m_arguments
(
arguments
)
,
m_type
(
type
)
{
}
void
type
(
const
std
::
shared_ptr
<
ValueType
>&
t
){
m_type
=
t
;
}
virtual
~
Node
()
{}
virtual
std
::
vector
<
std
::
shared_ptr
<
Node
>>
dependents
()
{
return
m_arguments
;
}
void
type
(
const
ElementType
&
element_type
,
const
Shape
&
shape
){
void
type
(
const
std
::
shared_ptr
<
ValueType
>&
t
)
{
m_type
=
t
;
}
void
type
(
const
ElementType
&
element_type
,
const
Shape
&
shape
)
{
m_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
}
std
::
shared_ptr
<
ValueType
>
type
()
const
{
return
m_type
;
}
std
::
shared_ptr
<
ValueType
>
type
()
const
{
return
m_type
;
}
protected
:
std
::
vector
<
std
::
shared_ptr
<
Node
>>
m_
parameter
s
;
std
::
shared_ptr
<
ValueType
>
m_type
;
std
::
vector
<
std
::
shared_ptr
<
Node
>>
m_
argument
s
;
std
::
shared_ptr
<
ValueType
>
m_type
;
};
class
Call
:
public
Node
{
protected
:
Call
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
)
:
Node
(
arguments
,
0
)
{
}
};
}
\ No newline at end of file
src/values/op.cpp
0 → 100644
View file @
19f16bc1
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "values/op.hpp"
using
namespace
ngraph
;
Broadcast
ngraph
::
op
::
broadcast
{};
Dot
ngraph
::
op
::
dot
{};
\ No newline at end of file
src/values/op.hpp
View file @
19f16bc1
...
...
@@ -17,15 +17,67 @@
#include <memory>
#include "values/descriptor.hpp"
#include "values/node.hpp"
#include "values/type.hpp"
namespace
ngraph
{
class
Call
:
public
Node
class
Op
{
protected
:
std
::
vector
<
std
::
shared_ptr
<
Node
>>
m_args
;
};
class
Broadcast
:
public
Op
{
class
BroadcastCall
:
public
Call
{
friend
class
Broadcast
;
public
:
BroadcastCall
(
const
std
::
shared_ptr
<
Node
>&
arg
,
size_t
axis
)
:
Call
({
arg
})
,
m_axis
(
axis
)
{
}
protected
:
size_t
m_axis
;
};
public
:
std
::
shared_ptr
<
BroadcastCall
>
operator
()(
const
std
::
shared_ptr
<
Node
>&
tensor
,
size_t
axis
)
{
return
std
::
make_shared
<
BroadcastCall
>
(
tensor
,
axis
);
}
};
namespace
op
{
extern
Broadcast
broadcast
;
}
class
Dot
:
public
Op
{
class
DotCall
:
public
Call
{
friend
class
Dot
;
public
:
DotCall
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
:
Call
({
arg0
,
arg1
})
{
}
};
public
:
std
::
shared_ptr
<
DotCall
>
operator
()(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
return
std
::
make_shared
<
DotCall
>
(
arg0
,
arg1
);
}
};
namespace
op
{
extern
Dot
dot
;
}
}
\ No newline at end of file
src/values/type.hpp
View file @
19f16bc1
...
...
@@ -19,14 +19,15 @@
#include "element_type.hpp"
namespace
ngraph
{
namespace
ngraph
{
class
Shape
{
public
:
Shape
(
const
std
::
initializer_list
<
size_t
>&
sizes
)
:
m_sizes
(
sizes
)
{}
:
m_sizes
(
sizes
)
{
}
protected
:
std
::
vector
<
size_t
>
m_sizes
;
...
...
@@ -43,23 +44,24 @@ namespace ngraph {
{
public
:
TensorViewType
(
const
ElementType
&
element_type
,
const
Shape
&
shape
)
:
m_element_type
(
element_type
)
,
m_shape
(
shape
)
{}
:
m_element_type
(
element_type
)
,
m_shape
(
shape
)
{
}
protected
:
TensorViewType
(
const
TensorViewType
&
)
=
delete
;
const
ElementType
&
m_element_type
;
Shape
m_shape
;
Shape
m_shape
;
};
class
TupleType
:
public
ValueType
{
public
:
TupleType
(
const
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>&
element_types
)
:
m_element_types
(
element_types
)
{}
:
m_element_types
(
element_types
)
{
}
protected
:
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
m_element_types
;
...
...
test/build_graph.cpp
View file @
19f16bc1
...
...
@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "values/type.hpp"
#include "values/function.hpp"
using
namespace
std
;
using
namespace
ngraph
;
void
build_simple_graph
(
)
TEST
(
graph
,
build_simple
)
{
// Function with 4 parameters
auto
cluster_0
=
make_shared
<
Function
>
(
4
);
...
...
@@ -29,11 +31,14 @@ void build_simple_graph()
cluster_0
->
parameter
(
3
)
->
type
(
element_type_float
,
Shape
{
Shape
{
32
,
7
}});
auto
arg3
=
cluster_0
->
parameter
(
3
);
// call broadcast op on arg3, broadcasting on axis 1.
//
auto broadcast_1 = op::broadcast(arg3, 1);
auto
broadcast_1
=
op
::
broadcast
(
arg3
,
1
);
auto
arg2
=
cluster_0
->
parameter
(
2
);
auto
arg0
=
cluster_0
->
parameter
(
0
);
// call dot op
//auto dot = op::dot(arg2, arg0);
auto
dot
=
op
::
dot
(
arg2
,
arg0
);
ASSERT_EQ
(
dot
->
dependents
()[
0
],
arg2
);
// Function returns tuple of dot and broadcast_1.
//cluster_0.result->value(op::tuple(dot, broadcast_1));
cluster_0
->
result
()
->
value
(
dot
);
ASSERT_EQ
(
cluster_0
->
result
()
->
value
(),
dot
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment