Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
1b026daa
Commit
1b026daa
authored
Aug 28, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Remove Op -- use typeid.
parent
5f8bf07e
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
142 deletions
+74
-142
function.hpp
src/ngraph/function.hpp
+2
-2
op.hpp
src/ngraph/op.hpp
+72
-106
op.cpp
src/ops/op.cpp
+0
-34
No files found.
src/ngraph/function.hpp
View file @
1b026daa
...
...
@@ -61,7 +61,7 @@ namespace ngraph
/**
** A user-defined function.
**/
class
Function
:
public
Op
class
Function
{
public
:
Function
(
size_t
n_parameters
);
...
...
@@ -70,7 +70,7 @@ namespace ngraph
Parameter
::
ptr
parameter
(
size_t
i
)
{
return
m_parameters
[
i
];
}
std
::
string
name
()
const
override
{
return
m_name
;
}
std
::
string
name
()
const
{
return
m_name
;
}
protected
:
std
::
vector
<
Parameter
::
ptr
>
m_parameters
;
...
...
src/ngraph/op.hpp
View file @
1b026daa
...
...
@@ -64,18 +64,6 @@ namespace ngraph
//Node::ptr while();
}
/**
** Every instance of Op corresponds to a unique defined operation.
**/
class
Op
{
protected
:
virtual
~
Op
()
{}
public
:
virtual
std
::
string
name
()
const
=
0
;
};
/**
** Call nodes are nodes whose value is the result of some operation, the op,
** applied to its arguments. We use the op as a callable to construct the
...
...
@@ -84,53 +72,34 @@ namespace ngraph
class
Call
:
public
Node
{
public
:
const
Op
&
op
()
const
{
return
m_op
;
}
Call
(
const
Op
&
op
,
const
std
::
vector
<
Node
::
ptr
>&
arguments
)
Call
(
const
std
::
vector
<
Node
::
ptr
>&
arguments
)
:
Node
(
arguments
,
nullptr
)
,
m_op
(
op
)
{
}
virtual
std
::
string
description
()
const
override
{
return
m_op
.
name
();
}
protected
:
const
Op
&
m_op
;
};
/**
** There is exactly one instance of builtin op for each pre-defined operation. These
** are intended to be used when matching calls in different graphs; every FooCall
** will have the same op.
**/
class
BuiltinOp
:
public
Op
{
friend
class
Call
;
public
:
BuiltinOp
(
const
std
::
string
&
name
)
:
m_name
(
name
)
{
}
public
:
std
::
string
name
()
const
override
{
return
m_name
;
}
protected
:
std
::
string
m_name
;
/**
** Return true if this has the same implementing class as call. This
** will be used by the pattern matcher when comparing a pattern
** graph against the graph.
**/
bool
has_same_op
(
Call
&
call
)
{
return
typeid
(
this
)
==
typeid
(
&
call
);
}
virtual
std
::
string
description
()
const
override
{
return
"Call"
;
}
};
class
BuiltinCall
:
public
Call
{
public
:
virtual
std
::
string
description
()
const
override
{
return
"BuiltinCall"
;
}
/// Name of the builtin op, for debugging and logging.
virtual
std
::
string
op_name
()
const
=
0
;
// TODO: Implement for each op
virtual
void
propagate_types
()
override
{}
protected
:
BuiltinCall
(
const
Op
&
op
,
const
std
::
vector
<
Node
::
ptr
>&
args
)
:
Call
(
op
,
args
)
BuiltinCall
(
const
std
::
vector
<
Node
::
ptr
>&
args
)
:
Call
(
args
)
{
}
};
...
...
@@ -139,24 +108,23 @@ namespace ngraph
{
public
:
AbsCall
(
const
Node
::
ptr
&
arg0
)
:
BuiltinCall
(
s_op
,
{
arg0
})
:
BuiltinCall
({
arg0
})
{
}
protected
:
static
BuiltinOp
s_op
;
virtual
std
::
string
op_name
()
const
override
{
return
"abs"
;
}
//virtual void propagate_types() override;
};
class
AddCall
:
public
BuiltinCall
{
public
:
AddCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
virtual
std
::
string
op_name
()
const
override
{
return
"add"
;
}
//virtual void propagate_types() override;
};
class
BroadcastCall
:
public
BuiltinCall
...
...
@@ -169,43 +137,42 @@ namespace ngraph
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastCall
(
const
Node
::
ptr
&
arg
,
const
Shape
&
shape
,
std
::
vector
<
size_t
>
broadcast_axes
)
:
BuiltinCall
(
s_op
,
{
arg
})
:
BuiltinCall
({
arg
})
,
m_shape
(
shape
)
,
m_broadcast_axes
(
broadcast_axes
)
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"broadcast"
;
}
virtual
void
propagate_types
()
override
;
protected
:
Shape
m_shape
;
std
::
vector
<
size_t
>
m_broadcast_axes
;
static
BuiltinOp
s_op
;
};
class
CeilingCall
:
public
BuiltinCall
{
public
:
CeilingCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"ceiling"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
DivideCall
:
public
BuiltinCall
{
public
:
DivideCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"divide"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
DotCall
:
public
BuiltinCall
...
...
@@ -213,183 +180,182 @@ namespace ngraph
public
:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"dot"
;
}
virtual
void
propagate_types
()
override
;
protected
:
static
BuiltinOp
s_op
;
};
class
EqualCall
:
public
BuiltinCall
{
public
:
EqualCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"equal"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
ExponentialCall
:
public
BuiltinCall
{
public
:
ExponentialCall
(
const
Node
::
ptr
&
arg0
)
:
BuiltinCall
(
s_op
,
{
arg0
})
:
BuiltinCall
({
arg0
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"exp"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
FloorCall
:
public
BuiltinCall
{
public
:
FloorCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"floor"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
GreaterCall
:
public
BuiltinCall
{
public
:
GreaterCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"greater"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
LessCall
:
public
BuiltinCall
{
public
:
LessCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"less"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
LogCall
:
public
BuiltinCall
{
public
:
LogCall
(
const
Node
::
ptr
&
arg0
)
:
BuiltinCall
(
s_op
,
{
arg0
})
:
BuiltinCall
({
arg0
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"log"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
MaximumCall
:
public
BuiltinCall
{
public
:
MaximumCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"max"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
MinimumCall
:
public
BuiltinCall
{
public
:
MinimumCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"min"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
MultiplyCall
:
public
BuiltinCall
{
public
:
MultiplyCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"multiply"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
NegateCall
:
public
BuiltinCall
{
public
:
NegateCall
(
const
Node
::
ptr
&
arg0
)
:
BuiltinCall
(
s_op
,
{
arg0
})
:
BuiltinCall
({
arg0
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"negate"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
PowerCall
:
public
BuiltinCall
{
public
:
PowerCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"power"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
RemainderCall
:
public
BuiltinCall
{
public
:
RemainderCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"remainder"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
class
ReshapeCall
:
public
BuiltinCall
{
public
:
ReshapeCall
(
const
Node
::
ptr
&
arg0
,
const
Shape
&
shape
)
:
BuiltinCall
(
s_op
,
{
arg0
})
:
BuiltinCall
({
arg0
})
,
m_shape
(
shape
)
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"reshape"
;
}
//virtual void propagate_types() override;
protected
:
Shape
m_shape
;
static
BuiltinOp
s_op
;
};
class
SubtractCall
:
public
BuiltinCall
{
public
:
SubtractCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
:
BuiltinCall
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_name
()
const
override
{
return
"subtract"
;
}
//virtual void propagate_types() override;
protected
:
static
BuiltinOp
s_op
;
};
}
src/ops/op.cpp
View file @
1b026daa
...
...
@@ -19,22 +19,16 @@
using
namespace
ngraph
;
using
namespace
std
;
BuiltinOp
AbsCall
::
s_op
=
BuiltinOp
(
"abs"
);
Node
::
ptr
ngraph
::
op
::
abs
(
const
Node
::
ptr
&
arg
)
{
return
make_shared
<
AbsCall
>
(
arg
);
}
BuiltinOp
AddCall
::
s_op
=
BuiltinOp
(
"add"
);
Node
::
ptr
ngraph
::
op
::
add
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
AddCall
>
(
arg0
,
arg1
);
}
BuiltinOp
BroadcastCall
::
s_op
=
BuiltinOp
(
"broadcast"
);
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
...
...
@@ -74,8 +68,6 @@ void BroadcastCall::propagate_types()
m_type
=
make_shared
<
TensorViewType
>
(
arg_tensor_view_type
->
element_type
(),
m_shape
);
}
BuiltinOp
CeilingCall
::
s_op
=
BuiltinOp
(
"ceiling"
);
Node
::
ptr
ngraph
::
op
::
ceiling
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
CeilingCall
>
(
arg0
,
arg1
);
...
...
@@ -86,15 +78,11 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
// 'convert',
// 'convolution',
BuiltinOp
DivideCall
::
s_op
=
BuiltinOp
(
"divide"
);
Node
::
ptr
ngraph
::
op
::
divide
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
DivideCall
>
(
arg0
,
arg1
);
}
BuiltinOp
DotCall
::
s_op
=
BuiltinOp
(
"dot"
);
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node
::
ptr
ngraph
::
op
::
dot
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
...
...
@@ -139,50 +127,36 @@ void DotCall::propagate_types()
m_type
=
make_shared
<
TensorViewType
>
(
arg0_tensor_type
->
element_type
(),
result_shape
);
}
BuiltinOp
ExponentialCall
::
s_op
=
BuiltinOp
(
"exponential"
);
Node
::
ptr
ngraph
::
op
::
exponential
(
const
Node
::
ptr
&
arg0
)
{
return
make_shared
<
ExponentialCall
>
(
arg0
);
}
BuiltinOp
FloorCall
::
s_op
=
BuiltinOp
(
"floor"
);
Node
::
ptr
ngraph
::
op
::
floor
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
FloorCall
>
(
arg0
,
arg1
);
}
BuiltinOp
LogCall
::
s_op
=
BuiltinOp
(
"log"
);
Node
::
ptr
ngraph
::
op
::
log
(
const
Node
::
ptr
&
arg0
)
{
return
make_shared
<
LogCall
>
(
arg0
);
}
BuiltinOp
MaximumCall
::
s_op
=
BuiltinOp
(
"maximum"
);
Node
::
ptr
ngraph
::
op
::
maximum
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
MaximumCall
>
(
arg0
,
arg1
);
}
BuiltinOp
MinimumCall
::
s_op
=
BuiltinOp
(
"minimum"
);
Node
::
ptr
ngraph
::
op
::
minimum
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
MinimumCall
>
(
arg0
,
arg1
);
}
BuiltinOp
MultiplyCall
::
s_op
=
BuiltinOp
(
"multiply"
);
Node
::
ptr
ngraph
::
op
::
multiply
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
MultiplyCall
>
(
arg0
,
arg1
);
}
BuiltinOp
NegateCall
::
s_op
=
BuiltinOp
(
"negate"
);
Node
::
ptr
ngraph
::
op
::
negate
(
const
Node
::
ptr
&
arg0
)
{
return
make_shared
<
NegateCall
>
(
arg0
);
...
...
@@ -191,8 +165,6 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
// 'pad',
// 'parameter',
BuiltinOp
PowerCall
::
s_op
=
BuiltinOp
(
"power"
);
Node
::
ptr
ngraph
::
op
::
power
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
PowerCall
>
(
arg0
,
arg1
);
...
...
@@ -200,15 +172,11 @@ Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
//'reduce',
BuiltinOp
RemainderCall
::
s_op
=
BuiltinOp
(
"remainder"
);
Node
::
ptr
ngraph
::
op
::
remainder
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
RemainderCall
>
(
arg0
,
arg1
);
}
BuiltinOp
ReshapeCall
::
s_op
=
BuiltinOp
(
"reshape"
);
Node
::
ptr
ngraph
::
op
::
reshape
(
const
Node
::
ptr
&
arg0
,
const
Shape
&
shape
)
{
return
make_shared
<
ReshapeCall
>
(
arg0
,
shape
);
...
...
@@ -219,8 +187,6 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
// 'select',
//'slice',
BuiltinOp
SubtractCall
::
s_op
=
BuiltinOp
(
"subtract"
);
Node
::
ptr
ngraph
::
op
::
subtract
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
SubtractCall
>
(
arg0
,
arg1
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment