Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
271fb025
Commit
271fb025
authored
Aug 23, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Organize files, add method to get op from call.
parent
c1806e85
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
258 additions
and
134 deletions
+258
-134
CMakeLists.txt
src/CMakeLists.txt
+3
-3
element_type.hpp
src/ngraph/element_type.hpp
+4
-4
function.hpp
src/ngraph/function.hpp
+16
-29
ngraph.hpp
src/ngraph/ngraph.hpp
+10
-9
node.hpp
src/ngraph/node.hpp
+14
-18
op.hpp
src/ngraph/op.hpp
+10
-8
shape.hpp
src/ngraph/shape.hpp
+12
-37
type.hpp
src/ngraph/type.hpp
+137
-0
function.cpp
src/ops/function.cpp
+17
-1
op.cpp
src/ops/op.cpp
+13
-3
element_type.cpp
src/types/element_type.cpp
+1
-1
build_graph.cpp
test/build_graph.cpp
+20
-20
element_type.cpp
test/element_type.cpp
+1
-1
No files found.
src/CMakeLists.txt
View file @
271fb025
...
@@ -15,12 +15,12 @@ get_filename_component( NGRAPH_INCLUDE_DIR . ABSOLUTE)
...
@@ -15,12 +15,12 @@ get_filename_component( NGRAPH_INCLUDE_DIR . ABSOLUTE)
set
(
NGRAPH_INCLUDE_DIR
"
${
NGRAPH_INCLUDE_DIR
}
"
PARENT_SCOPE
)
set
(
NGRAPH_INCLUDE_DIR
"
${
NGRAPH_INCLUDE_DIR
}
"
PARENT_SCOPE
)
set
(
SRC
set
(
SRC
element_type.cpp
tree.cpp
tree.cpp
util.cpp
util.cpp
log.cpp
log.cpp
values/function.cpp
ops/function.cpp
values/op.cpp
ops/op.cpp
types/element_type.cpp
)
)
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
...
...
src/element_type.hpp
→
src/
ngraph/
element_type.hpp
View file @
271fb025
...
@@ -43,10 +43,10 @@ public:
...
@@ -43,10 +43,10 @@ public:
private
:
private
:
static
std
::
map
<
std
::
string
,
ElementType
>
m_element_list
;
static
std
::
map
<
std
::
string
,
ElementType
>
m_element_list
;
size_t
m_bitwidth
;
size_t
m_bitwidth
;
bool
m_is_float
;
bool
m_is_float
;
bool
m_is_signed
;
bool
m_is_signed
;
const
std
::
string
m_cname
;
const
std
::
string
m_cname
;
};
};
extern
const
ngraph
::
ElementType
element_type_float
;
extern
const
ngraph
::
ElementType
element_type_float
;
...
...
src/
values
/function.hpp
→
src/
ngraph
/function.hpp
View file @
271fb025
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
#pragma once
#pragma once
#include "
values
/node.hpp"
#include "
ngraph
/node.hpp"
#include "
values
/op.hpp"
#include "
ngraph
/op.hpp"
#include "
values
/type.hpp"
#include "
ngraph
/type.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -25,53 +25,39 @@ namespace ngraph
...
@@ -25,53 +25,39 @@ namespace ngraph
class
Parameter
:
public
Node
class
Parameter
:
public
Node
{
{
public
:
public
:
Parameter
(
Function
&
function
,
size_t
index
,
const
std
::
shared_ptr
<
ValueType
>&
type
)
using
ptr
=
std
::
shared_ptr
<
Parameter
>
;
:
Node
({},
type
)
,
m_function
(
function
)
Parameter
(
Function
&
function
,
size_t
index
);
,
m_index
(
index
)
{
}
protected
:
protected
:
Function
&
m_function
;
Function
&
m_function
;
size_t
m_index
;
size_t
m_index
;
};
};
class
Result
class
Result
:
public
TypedValueMixin
{
{
public
:
public
:
void
type
(
const
std
::
shared_ptr
<
ValueType
>&
t
)
{
m_type
=
t
;
}
using
ptr
=
std
::
shared_ptr
<
Result
>
;
void
type
(
const
ElementType
&
element_type
,
const
Shape
&
shape
)
{
m_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
}
std
::
shared_ptr
<
ValueType
>
type
()
const
{
return
m_type
;
}
std
::
shared_ptr
<
Node
>
value
()
const
{
return
m_value
;
}
Node
::
ptr
value
()
const
{
return
m_value
;
}
void
value
(
const
std
::
shared_ptr
<
Node
>
&
value
)
{
m_value
=
value
;
}
void
value
(
const
Node
::
ptr
&
value
)
{
m_value
=
value
;
}
protected
:
protected
:
std
::
shared_ptr
<
ValueType
>
m_type
;
Node
::
ptr
m_value
;
std
::
shared_ptr
<
Node
>
m_value
;
};
};
class
Function
class
Function
:
public
Op
{
{
public
:
public
:
Function
(
size_t
n_parameters
)
Function
(
size_t
n_parameters
);
:
m_parameters
(
n_parameters
)
{
}
Result
*
result
()
{
return
&
m_result
;
}
Result
*
result
()
{
return
&
m_result
;
}
std
::
shared_ptr
<
Parameter
>
parameter
(
size_t
i
)
{
return
m_parameters
[
i
];
}
std
::
shared_ptr
<
Parameter
>
parameter
(
size_t
i
)
{
return
m_parameters
[
i
];
}
protected
:
protected
:
std
::
vector
<
std
::
shared_ptr
<
Parameter
>
>
m_parameters
;
std
::
vector
<
Parameter
::
ptr
>
m_parameters
;
Result
m_result
;
Result
m_result
;
};
};
}
// end namespace ngraph
}
// end namespace ngraph
\ No newline at end of file
src/
values/descriptor
.hpp
→
src/
ngraph/ngraph
.hpp
View file @
271fb025
...
@@ -12,14 +12,15 @@
...
@@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#pragma once
//
// The public API for ngraph++
#include <algorithm>
//
#include <memory>
#include <vector>
#
include "values/type.hpp"
#
pragma once
namespace
ngraph
#include "ngraph/element_type.hpp"
{
#include "ngraph/function.hpp"
}
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
src/
values
/node.hpp
→
src/
ngraph
/node.hpp
View file @
271fb025
...
@@ -16,41 +16,37 @@
...
@@ -16,41 +16,37 @@
#include <vector>
#include <vector>
#include "
values
/type.hpp"
#include "
ngraph
/type.hpp"
namespace
ngraph
namespace
ngraph
{
{
class
Node
class
Op
;
class
Node
:
public
TypedValueMixin
{
{
public
:
public
:
Node
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
,
using
ptr
=
std
::
shared_ptr
<
Node
>
;
std
::
shared_ptr
<
ValueType
>
type
=
0
)
Node
(
const
std
::
vector
<
Node
::
ptr
>&
arguments
,
ValueType
::
ptr
type
=
0
)
:
m_arguments
(
arguments
)
:
m_arguments
(
arguments
)
,
m_type
(
type
)
,
TypedValueMixin
(
type
)
{
{
}
}
virtual
~
Node
()
{}
virtual
~
Node
()
{}
virtual
std
::
vector
<
std
::
shared_ptr
<
Node
>>
dependents
()
{
return
m_arguments
;
}
virtual
std
::
vector
<
Node
::
ptr
>
dependents
()
{
return
m_arguments
;
}
void
type
(
const
std
::
shared_ptr
<
ValueType
>&
t
)
{
m_type
=
t
;
}
void
type
(
const
ElementType
&
element_type
,
const
Shape
&
shape
)
{
m_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
}
std
::
shared_ptr
<
ValueType
>
type
()
const
{
return
m_type
;
}
protected
:
protected
:
std
::
vector
<
std
::
shared_ptr
<
Node
>>
m_arguments
;
std
::
vector
<
Node
::
ptr
>
m_arguments
;
std
::
shared_ptr
<
ValueType
>
m_type
;
};
};
class
Call
:
public
Node
class
Call
:
public
Node
{
{
public
:
virtual
Op
&
op
()
const
=
0
;
protected
:
protected
:
Call
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>
>&
arguments
)
Call
(
const
std
::
vector
<
Node
::
ptr
>&
arguments
)
:
Node
(
arguments
,
0
)
:
Node
(
arguments
,
0
)
{
{
}
}
...
...
src/
values
/op.hpp
→
src/
ngraph
/op.hpp
View file @
271fb025
...
@@ -16,9 +16,8 @@
...
@@ -16,9 +16,8 @@
#include <memory>
#include <memory>
#include "values/descriptor.hpp"
#include "ngraph/node.hpp"
#include "values/node.hpp"
#include "ngraph/type.hpp"
#include "values/type.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -33,18 +32,20 @@ namespace ngraph
...
@@ -33,18 +32,20 @@ namespace ngraph
friend
class
Broadcast
;
friend
class
Broadcast
;
public
:
public
:
BroadcastCall
(
const
std
::
shared_ptr
<
Node
>
&
arg
,
size_t
axis
)
BroadcastCall
(
const
Node
::
ptr
&
arg
,
size_t
axis
)
:
Call
({
arg
})
:
Call
({
arg
})
,
m_axis
(
axis
)
,
m_axis
(
axis
)
{
{
}
}
Op
&
op
()
const
override
;
protected
:
protected
:
size_t
m_axis
;
size_t
m_axis
;
};
};
public
:
public
:
std
::
shared_ptr
<
BroadcastCall
>
operator
()(
const
std
::
shared_ptr
<
Node
>
&
tensor
,
size_t
axis
)
std
::
shared_ptr
<
BroadcastCall
>
operator
()(
const
Node
::
ptr
&
tensor
,
size_t
axis
)
{
{
return
std
::
make_shared
<
BroadcastCall
>
(
tensor
,
axis
);
return
std
::
make_shared
<
BroadcastCall
>
(
tensor
,
axis
);
}
}
...
@@ -62,15 +63,16 @@ namespace ngraph
...
@@ -62,15 +63,16 @@ namespace ngraph
friend
class
Dot
;
friend
class
Dot
;
public
:
public
:
DotCall
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>
&
arg1
)
DotCall
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
Node
::
ptr
&
arg1
)
:
Call
({
arg0
,
arg1
})
:
Call
({
arg0
,
arg1
})
{
{
}
}
Op
&
op
()
const
override
;
};
};
public
:
public
:
std
::
shared_ptr
<
DotCall
>
operator
()(
const
std
::
shared_ptr
<
Node
>&
arg0
,
std
::
shared_ptr
<
DotCall
>
operator
()(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
const
std
::
shared_ptr
<
Node
>&
arg1
)
{
{
return
std
::
make_shared
<
DotCall
>
(
arg0
,
arg1
);
return
std
::
make_shared
<
DotCall
>
(
arg0
,
arg1
);
}
}
...
...
src/
values/ty
pe.hpp
→
src/
ngraph/sha
pe.hpp
View file @
271fb025
...
@@ -14,56 +14,30 @@
...
@@ -14,56 +14,30 @@
#pragma once
#pragma once
#include <memory>
#include <vector>
#include <vector>
#include "element_type.hpp"
namespace
ngraph
namespace
ngraph
{
{
/**
** Holds the shape of a tensor view.
**/
class
Shape
class
Shape
{
{
public
:
public
:
/**
** \param sizes A sequence of sizes.
**/
Shape
(
const
std
::
initializer_list
<
size_t
>&
sizes
)
Shape
(
const
std
::
initializer_list
<
size_t
>&
sizes
)
:
m_sizes
(
sizes
)
:
m_sizes
(
sizes
)
{
{
}
}
protected
:
/**
std
::
vector
<
size_t
>
m_sizes
;
** Conversion to a vector of sizes.
};
**/
operator
const
std
::
vector
<
size_t
>&
()
const
{
return
m_sizes
;
}
// ValueType is
// TensorViewType
// | TupleType(ValueType[])
class
ValueType
{
};
class
TensorViewType
:
public
ValueType
{
public
:
TensorViewType
(
const
ElementType
&
element_type
,
const
Shape
&
shape
)
:
m_element_type
(
element_type
)
,
m_shape
(
shape
)
{
}
protected
:
TensorViewType
(
const
TensorViewType
&
)
=
delete
;
const
ElementType
&
m_element_type
;
Shape
m_shape
;
};
class
TupleType
:
public
ValueType
{
public
:
TupleType
(
const
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>&
element_types
)
:
m_element_types
(
element_types
)
{
}
protected
:
protected
:
std
::
vector
<
s
td
::
shared_ptr
<
ValueType
>>
m_element_typ
es
;
std
::
vector
<
s
ize_t
>
m_siz
es
;
};
};
}
}
\ No newline at end of file
src/ngraph/type.hpp
0 → 100644
View file @
271fb025
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/element_type.hpp"
#include "ngraph/shape.hpp"
namespace
ngraph
{
/**
** ValueType is
** TensorViewType
** | TupleType(ValueType[])
**/
class
ValueType
{
public
:
/**
** Preferred handle
**/
using
ptr
=
std
::
shared_ptr
<
ValueType
>
;
};
/**
** Describes a tensor view; an element type and a shape.
**/
class
TensorViewType
:
public
ValueType
{
public
:
/**
** Preferred handle
**/
using
ptr
=
std
::
shared_ptr
<
TensorViewType
>
;
/**
** /param element_type The type of the tensor elements.
** /param shape The shape of the tensor.
**/
TensorViewType
(
const
ElementType
&
element_type
,
const
Shape
&
shape
)
:
m_element_type
(
element_type
)
,
m_shape
(
shape
)
{
}
protected
:
const
ElementType
&
m_element_type
;
Shape
m_shape
;
};
/**
** Describes a tuple of values; a vector of types
**/
class
TupleType
:
public
ValueType
{
public
:
/**
** The preferred handle
**/
using
ptr
=
std
::
shared_ptr
<
ValueType
>
;
/**
** Construct empty tuple and add value types later.
**/
TupleType
()
{}
/**
** /param element_types A vector of types for the tuple elements
**/
TupleType
(
const
std
::
vector
<
ValueType
::
ptr
>&
element_types
)
:
m_element_types
(
element_types
)
{
}
const
std
::
vector
<
ValueType
::
ptr
>
element_types
()
const
{
return
m_element_types
;
}
std
::
vector
<
ValueType
::
ptr
>
element_types
()
{
return
m_element_types
;
}
protected
:
std
::
vector
<
ValueType
::
ptr
>
m_element_types
;
};
/**
** Mixin for objects with type information
**/
class
TypedValueMixin
{
public
:
TypedValueMixin
(
const
ValueType
::
ptr
&
type
=
0
)
:
m_type
(
type
)
{
}
/**
** Set the type
** /param type The new type
**/
void
type
(
const
ValueType
::
ptr
&
type
)
{
m_type
=
type
;
}
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void
type
(
const
ElementType
&
element_type
,
const
Shape
&
shape
)
{
m_type
=
TensorViewType
::
ptr
::
make_shared
(
element_type
,
shape
);
}
/**
** The type associated with this value.
**/
ValueType
::
ptr
type
()
{
return
m_type
;
}
/**
** The type associated with this value.
**/
const
ValueType
::
ptr
type
()
const
{
return
m_type
;
}
protected
:
ValueType
::
ptr
m_type
;
};
}
\ No newline at end of file
src/
value
s/function.cpp
→
src/
op
s/function.cpp
View file @
271fb025
...
@@ -12,7 +12,23 @@
...
@@ -12,7 +12,23 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include "
values/function
.hpp"
#include "
ngraph/ngraph
.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
Parameter
::
Parameter
(
Function
&
function
,
size_t
index
)
:
Node
({})
,
m_function
(
function
)
,
m_index
(
index
)
{
}
Function
::
Function
(
size_t
n_parameters
)
:
m_parameters
(
n_parameters
)
{
for
(
int
i
=
0
;
i
<
n_parameters
;
i
++
)
{
m_parameters
[
i
]
=
Parameter
::
ptr
::
make_shared
(
*
this
,
i
);
}
}
src/
value
s/op.cpp
→
src/
op
s/op.cpp
View file @
271fb025
...
@@ -12,9 +12,20 @@
...
@@ -12,9 +12,20 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include "
values/op
.hpp"
#include "
ngraph/ngraph
.hpp"
using
namespace
ngraph
;
using
namespace
ngraph
;
Broadcast
ngraph
::
op
::
broadcast
{};
Broadcast
ngraph
::
op
::
broadcast
{};
Dot
ngraph
::
op
::
dot
{};
\ No newline at end of file
Op
&
ngraph
::
Broadcast
::
BroadcastCall
::
op
()
const
{
return
op
::
broadcast
;
}
Dot
ngraph
::
op
::
dot
{};
Op
&
ngraph
::
Dot
::
DotCall
::
op
()
const
{
return
op
::
dot
;
}
src/element_type.cpp
→
src/
types/
element_type.cpp
View file @
271fb025
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#include <cassert>
#include <cassert>
#include <cmath>
#include <cmath>
#include "element_type.hpp"
#include "
ngraph/
element_type.hpp"
const
ngraph
::
ElementType
element_type_float
=
ngraph
::
ElementType
(
32
,
true
,
true
,
"float"
);
const
ngraph
::
ElementType
element_type_float
=
ngraph
::
ElementType
(
32
,
true
,
true
,
"float"
);
const
ngraph
::
ElementType
element_type_int8_t
=
ngraph
::
ElementType
(
8
,
false
,
true
,
"int8_t"
);
const
ngraph
::
ElementType
element_type_int8_t
=
ngraph
::
ElementType
(
8
,
false
,
true
,
"int8_t"
);
...
...
test/build_graph.cpp
View file @
271fb025
...
@@ -14,31 +14,31 @@
...
@@ -14,31 +14,31 @@
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "values/type.hpp"
#include "ngraph/ngraph.hpp"
#include "values/function.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
TEST
(
graph
,
build_simple
)
TEST
(
graph
,
build_simple
)
{
{
// // Function with 4 parameters
// Function with 4 parameters
// auto cluster_0 = make_shared<Function>(4);
auto
cluster_0
=
make_shared
<
Function
>
(
4
);
// cluster_0->result()->type(element_type_float, Shape {32, 3});
cluster_0
->
result
()
->
type
(
element_type_float
,
{
32
,
3
});
// cluster_0->parameter(0)->type(element_type_float, Shape {Shape {7, 3}});
cluster_0
->
parameter
(
0
)
->
type
(
element_type_float
,
{
7
,
3
});
// cluster_0->parameter(1)->type(element_type_float, Shape {Shape {3}});
cluster_0
->
parameter
(
1
)
->
type
(
element_type_float
,
{
3
});
// cluster_0->parameter(2)->type(element_type_float, Shape {Shape {32, 7}});
cluster_0
->
parameter
(
2
)
->
type
(
element_type_float
,
{
32
,
7
});
// cluster_0->parameter(3)->type(element_type_float, Shape {Shape {32, 7}});
cluster_0
->
parameter
(
3
)
->
type
(
element_type_float
,
{
32
,
7
});
// auto arg3 = cluster_0->parameter(3);
auto
arg3
=
cluster_0
->
parameter
(
3
);
// // call broadcast op on arg3, broadcasting on axis 1.
// call broadcast op on arg3, broadcasting on axis 1.
// auto broadcast_1 = op::broadcast(arg3, 1);
auto
broadcast_1
=
op
::
broadcast
(
arg3
,
1
);
// auto arg2 = cluster_0->parameter(2);
auto
arg2
=
cluster_0
->
parameter
(
2
);
// auto arg0 = cluster_0->parameter(0);
auto
arg0
=
cluster_0
->
parameter
(
0
);
// // call dot op
// call dot op
// auto dot = op::dot(arg2, arg0);
auto
dot
=
op
::
dot
(
arg2
,
arg0
);
// ASSERT_EQ(dot->dependents()[0], arg2);
ASSERT_EQ
(
dot
->
dependents
()[
0
],
arg2
);
// // Function returns tuple of dot and broadcast_1.
ASSERT_EQ
(
dot
->
dependents
()[
1
],
arg0
);
// cluster_0->result()->value(dot);
// Function returns tuple of dot and broadcast_1.
cluster_0
->
result
()
->
value
(
dot
);
//
ASSERT_EQ(cluster_0->result()->value(), dot);
ASSERT_EQ
(
cluster_0
->
result
()
->
value
(),
dot
);
}
}
test/element_type.cpp
View file @
271fb025
...
@@ -18,6 +18,6 @@
...
@@ -18,6 +18,6 @@
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "element_type.hpp"
#include "
ngraph/
element_type.hpp"
using
namespace
ngraph
;
using
namespace
ngraph
;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment