Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
6d6e923b
Commit
6d6e923b
authored
Sep 01, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Try universal node creator
parent
42f599c6
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
45 additions
and
48 deletions
+45
-48
common.hpp
src/ngraph/common.hpp
+15
-2
broadcast.hpp
src/ngraph/ops/broadcast.hpp
+3
-5
constant.hpp
src/ngraph/ops/constant.hpp
+0
-8
parameter.hpp
src/ngraph/ops/parameter.hpp
+1
-0
broadcast.cpp
src/ops/broadcast.cpp
+1
-1
parameter.cpp
src/ops/parameter.cpp
+5
-0
build_graph.cpp
test/build_graph.cpp
+20
-32
No files found.
src/ngraph/common.hpp
View file @
6d6e923b
...
...
@@ -23,12 +23,25 @@ namespace ngraph
{
class
Node
;
class
Parameter
;
class
ValueType
;
template
<
typename
T
,
typename
...
A
>
std
::
shared_ptr
<
T
>
node
(
A
&&
...
args
)
{
return
std
::
make_shared
<
T
>
(
args
...);
}
/// Zero or more value types
using
ValueTypes
=
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
;
/// Zero or more nodes
using
Nodes
=
std
::
vector
<
std
::
shared_ptr
<
Node
>>
;
/// A sequence of axes
using
AxisList
=
std
::
vector
<
size_t
>
;
/// A set of
indic
es, for example, reduction axes
using
Index
Set
=
std
::
set
<
size_t
>
;
/// A set of
ax
es, for example, reduction axes
using
Axis
Set
=
std
::
set
<
size_t
>
;
/// A list of parameters
using
Parameters
=
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>
;
...
...
src/ngraph/ops/broadcast.hpp
View file @
6d6e923b
...
...
@@ -19,15 +19,13 @@ namespace ngraph
class
BroadcastOp
:
public
BuiltinOp
{
public
:
using
Axes
=
std
::
vector
<
size_t
>
;
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastOp
(
const
Node
::
ptr
&
arg
,
const
Shape
&
shape
,
const
Axes
&
broadcast_axes
)
BroadcastOp
(
const
Node
::
ptr
&
arg
,
const
Shape
&
shape
,
AxisSet
&
broadcast_axes
)
:
BuiltinOp
({
arg
})
,
m_shape
(
shape
)
,
m_broadcast_axes
(
broadcast_axes
)
...
...
@@ -39,13 +37,13 @@ namespace ngraph
protected
:
Shape
m_shape
;
Ax
es
m_broadcast_axes
;
Ax
isSet
m_broadcast_axes
;
};
namespace
op
{
Node
::
ptr
broadcast
(
const
Node
::
ptr
&
tensor
,
const
Shape
&
shape
,
const
BroadcastOp
::
Axes
&&
broadcast_axes
);
AxisSet
&&
broadcast_axes
);
}
}
src/ngraph/ops/constant.hpp
View file @
6d6e923b
...
...
@@ -59,14 +59,6 @@ namespace ngraph
typename
T
::
type
value
()
const
{
return
m_value
;
}
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
template
<
typename
U
>
static
std
::
shared_ptr
<
ScalarConstant
<
T
>>
make
(
U
value
)
{
return
std
::
make_shared
<
ScalarConstant
<
T
>>
(
value
);
}
protected
:
typename
T
::
type
m_value
;
};
...
...
src/ngraph/ops/parameter.hpp
View file @
6d6e923b
...
...
@@ -37,6 +37,7 @@ namespace ngraph
public
:
Parameter
(
const
ValueType
::
ptr
&
value_type
);
Parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
);
std
::
string
description
()
const
override
{
return
"Parameter"
;
}
virtual
void
propagate_types
()
override
;
...
...
src/ops/broadcast.cpp
View file @
6d6e923b
...
...
@@ -25,7 +25,7 @@ using namespace ngraph;
**/
Node
::
ptr
ngraph
::
op
::
broadcast
(
const
Node
::
ptr
&
tensor
,
const
Shape
&
shape
,
const
BroadcastOp
::
Axes
&&
broadcast_axes
)
AxisSet
&&
broadcast_axes
)
{
return
make_shared
<
BroadcastOp
>
(
tensor
,
shape
,
broadcast_axes
);
}
...
...
src/ops/parameter.cpp
View file @
6d6e923b
...
...
@@ -26,6 +26,11 @@ Parameter::Parameter(const ValueType::ptr& value_type)
{
}
Parameter
::
Parameter
(
const
ngraph
::
element
::
Type
element_type
,
const
Shape
&
shape
)
:
Parameter
(
make_shared
<
TensorViewType
>
(
element_type
,
shape
))
{
}
void
Parameter
::
assign_function
(
Function
*
function
,
size_t
index
)
{
if
(
nullptr
!=
m_function
)
...
...
test/build_graph.cpp
View file @
6d6e923b
...
...
@@ -20,29 +20,16 @@
using
namespace
std
;
using
namespace
ngraph
;
template
<
typename
T
,
typename
...
A
>
std
::
shared_ptr
<
T
>
myfun
(
A
&&
...
args
)
{
return
std
::
make_shared
<
T
>
(
args
...);
}
template
<>
std
::
shared_ptr
<
Parameter
>
myfun
<
Parameter
>
(
ngraph
::
element
::
Type
&&
element_type
,
Shape
&&
shape
)
{
return
make_shared
<
Parameter
>
(
make_shared
<
TensorViewType
>
(
element_type
,
shape
));
}
TEST
(
build_graph
,
build_simple
)
{
// Function with 4 parameters
auto
arg0
=
op
::
parameter
(
element
::
Float
::
element_type
(),
Shape
{
7
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
element_type
(),
Shape
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
arg3
=
op
::
parameter
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
broadcast_1
=
op
::
broadcast
(
arg3
,
Shape
{
10
,
32
,
7
},
BroadcastOp
::
Axes
{
0
});
auto
b1
=
myfun
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
BroadcastOp
::
Axes
{
0
});
auto
dot
=
op
::
dot
(
arg2
,
arg0
);
auto
arg0
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
7
,
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
3
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
arg3
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
broadcast_1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
b1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
dot
=
node
<
DotOp
>
(
arg2
,
arg0
);
ASSERT_EQ
(
dot
->
arguments
()[
0
],
arg2
);
ASSERT_EQ
(
dot
->
arguments
()[
1
],
arg0
);
...
...
@@ -55,14 +42,14 @@ TEST(build_graph, build_simple)
TEST
(
build_graph
,
as_type
)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType
::
ptr
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
(),
Shape
{
2
,
3
,
5
});
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
auto
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
(),
Shape
{
2
,
3
,
5
});
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
ASSERT_EQ
(
tv_vt
,
tv_tv
);
auto
tv_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tv_vt
);
ASSERT_EQ
(
nullptr
,
tv_tp
);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
ValueType
::
ptr
tp_vt
=
make_shared
<
TupleType
>
(
vector
<
ValueType
::
ptr
>
{
tv_vt
,
tv_vt
});
auto
tp_vt
=
make_shared
<
TupleType
>
(
ValueTypes
{
tv_vt
,
tv_vt
});
auto
tp_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tp_vt
);
ASSERT_EQ
(
nullptr
,
tp_tv
);
auto
tp_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tp_vt
);
...
...
@@ -72,15 +59,15 @@ TEST(build_graph, as_type)
// Check node comparisons
TEST
(
build_graph
,
node_comparison
)
{
auto
arg0
=
op
::
parameter
(
element
::
Float
::
element_type
(),
{
32
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
element_type
(),
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
element_type
(),
{
32
});
auto
arg0
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
3
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
});
auto
dot
=
op
::
dot
(
arg0
,
arg1
);
auto
add
=
op
::
add
(
dot
,
arg2
);
auto
parg
=
op
::
parameter
(
element
::
Float
::
element_type
(),
{});
auto
pattern_dot
=
op
::
dot
(
parg
,
parg
);
auto
parg
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{});
auto
pattern_dot
=
node
<
DotOp
>
(
parg
,
parg
);
ASSERT_TRUE
(
pattern_dot
->
is_same_op_type
(
dot
));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
...
...
@@ -90,20 +77,21 @@ TEST(build_graph, node_comparison)
TEST
(
build_graph
,
literal
)
{
// float scalar from a float
auto
float0
=
FloatScalarConstant
::
make
(
3.0
);
//auto float0 = FloatScalarConstant::make(3.0);
auto
float0
=
node
<
FloatScalarConstant
>
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
(),
Shape
{});
ASSERT_EQ
(
float0
->
value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
value_type
(),
float_scalar_type
);
auto
d
=
op
::
dot
(
float0
,
float0
);
auto
d
=
node
<
DotOp
>
(
float0
,
float0
);
ASSERT_EQ
(
d
->
arguments
().
at
(
0
),
float0
);
ASSERT_EQ
(
d
->
arguments
().
at
(
1
),
float0
);
// float scalar from an int
auto
float1
=
FloatScalarConstant
::
make
(
3
);
auto
float1
=
node
<
FloatScalarConstant
>
(
3
);
ASSERT_EQ
(
float1
->
value
(),
3
);
ASSERT_EQ
(
*
float1
->
value_type
(),
float_scalar_type
);
auto
int32_0
=
Int32ScalarConstant
::
make
(
3.0
);
auto
int32_0
=
node
<
Int32ScalarConstant
>
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{});
ASSERT_EQ
(
int32_0
->
value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
value_type
(),
int32_scalar_type
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment