Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
a136956b
Commit
a136956b
authored
Aug 26, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Two shape propagates/checks, bulk of ops.
parent
9d40c6b2
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
540 additions
and
29 deletions
+540
-29
element_type.hpp
src/ngraph/element_type.hpp
+16
-15
except.hpp
src/ngraph/except.hpp
+34
-0
function.hpp
src/ngraph/function.hpp
+2
-0
ngraph.hpp
src/ngraph/ngraph.hpp
+1
-0
node.hpp
src/ngraph/node.hpp
+1
-1
op.hpp
src/ngraph/op.hpp
+265
-8
shape.hpp
src/ngraph/shape.hpp
+8
-0
type.hpp
src/ngraph/type.hpp
+10
-0
function.cpp
src/ops/function.cpp
+8
-0
op.cpp
src/ops/op.cpp
+194
-4
build_graph.cpp
test/build_graph.cpp
+1
-1
No files found.
src/ngraph/element_type.hpp
View file @
a136956b
...
...
@@ -29,7 +29,7 @@ namespace ngraph
{
public
:
Type
(
size_t
bitwidth
,
bool
is_float
,
bool
is_signed
,
const
std
::
string
&
cname
);
const
std
::
string
&
c_type_string
()
const
;
size_t
size
()
const
;
size_t
hash
()
const
...
...
@@ -37,23 +37,24 @@ namespace ngraph
std
::
hash
<
std
::
string
>
h
;
return
h
(
m_cname
);
}
bool
operator
==
(
const
Type
&
other
)
const
;
bool
operator
!=
(
const
Type
&
other
)
const
{
return
!
(
*
this
==
other
);
}
private
:
static
std
::
map
<
std
::
string
,
Type
>
m_element_list
;
size_t
m_bitwidth
;
bool
m_is_float
;
bool
m_is_signed
;
const
std
::
string
m_cname
;
static
std
::
map
<
std
::
string
,
Type
>
m_element_list
;
size_t
m_bitwidth
;
bool
m_is_float
;
bool
m_is_signed
;
const
std
::
string
m_cname
;
};
const
Type
float32_t
=
Type
(
32
,
true
,
true
,
"float"
);
const
Type
int8_t
=
Type
(
8
,
false
,
true
,
"int8_t"
);
const
Type
int32_t
=
Type
(
32
,
false
,
true
,
"int32_t"
);
const
Type
int64_t
=
Type
(
64
,
false
,
true
,
"int64_t"
);
const
Type
uint8_t
=
Type
(
8
,
false
,
false
,
"int8_t"
);
const
Type
uint32_t
=
Type
(
32
,
false
,
false
,
"int32_t"
);
const
Type
uint64_t
=
Type
(
64
,
false
,
false
,
"int64_t"
);
const
Type
float32_t
=
Type
(
32
,
true
,
true
,
"float"
);
const
Type
int8_t
=
Type
(
8
,
false
,
true
,
"int8_t"
);
const
Type
int32_t
=
Type
(
32
,
false
,
true
,
"int32_t"
);
const
Type
int64_t
=
Type
(
64
,
false
,
true
,
"int64_t"
);
const
Type
uint8_t
=
Type
(
8
,
false
,
false
,
"int8_t"
);
const
Type
uint32_t
=
Type
(
32
,
false
,
false
,
"int32_t"
);
const
Type
uint64_t
=
Type
(
64
,
false
,
false
,
"int64_t"
);
}
}
src/ngraph/except.hpp
0 → 100644
View file @
a136956b
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <stdexcept>
namespace
ngraph
{
/// Base error for ngraph runtime errors.
struct
ngraph_error
:
std
::
runtime_error
{
explicit
ngraph_error
(
const
std
::
string
&
what_arg
)
:
std
::
runtime_error
(
what_arg
)
{
}
explicit
ngraph_error
(
const
char
*
what_arg
)
:
std
::
runtime_error
(
what_arg
)
{
}
};
}
src/ngraph/function.hpp
View file @
a136956b
...
...
@@ -35,6 +35,8 @@ namespace ngraph
std
::
string
description
()
const
override
{
return
"Parameter"
;
}
virtual
void
propagate_types
()
override
;
protected
:
Function
&
m_function
;
size_t
m_index
;
...
...
src/ngraph/ngraph.hpp
View file @
a136956b
...
...
@@ -19,6 +19,7 @@
#pragma once
#include "ngraph/element_type.hpp"
#include "ngraph/except.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
...
...
src/ngraph/node.hpp
View file @
a136956b
...
...
@@ -53,7 +53,7 @@ namespace ngraph
virtual
std
::
string
description
()
const
=
0
;
/// Propagate types and check arguments for consistency
//
virtual void propagate_types() = 0;
virtual
void
propagate_types
()
=
0
;
const
std
::
vector
<
Node
::
ptr
>
arguments
()
const
{
return
m_arguments
;
}
std
::
vector
<
Node
::
ptr
>
arguments
()
{
return
m_arguments
;
}
...
...
src/ngraph/op.hpp
View file @
a136956b
...
...
@@ -21,6 +21,49 @@
namespace
ngraph
{
namespace
op
{
Node
::
ptr
abs
(
const
Node
::
ptr
&
arg
);
Node
::
ptr
add
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
broadcast
(
const
Node
::
ptr
&
tensor
,
const
Shape
&
shape
,
const
std
::
vector
<
size_t
>&
broadcast_axes
);
//Node::ptr candidate();
Node
::
ptr
ceiling
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
//Node::ptr concatenate();
//Node::ptr constant();
//Node::ptr convert();
//Node::ptr convolution();
Node
::
ptr
divide
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
dot
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
equal
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
exponential
(
const
Node
::
ptr
&
arg0
);
Node
::
ptr
floor
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
//Node::ptr get();
Node
::
ptr
greater
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
less
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
log
(
const
Node
::
ptr
&
arg0
);
//Node::ptr logical();
Node
::
ptr
maximum
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
minimum
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
multiply
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
negate
(
const
Node
::
ptr
&
arg0
);
//Node::ptr pad();
Node
::
ptr
power
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
//Node::ptr reduce();
Node
::
ptr
remainder
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
Node
::
ptr
reshape
(
const
Node
::
ptr
&
arg0
,
const
Shape
&
shape
);
//Node::ptr reverse();
//Node::ptr rng();
//Node::ptr select();
//Node::ptr slice();
Node
::
ptr
subtract
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
//Node::ptr transpose();
//Node::ptr tuple();
//Node::ptr while();
}
/**
** Every instance of Op corresponds to a unique defined operation.
**/
...
...
@@ -82,6 +125,9 @@ namespace ngraph
public
:
virtual
std
::
string
description
()
const
override
{
return
"BuiltinCall"
;
}
// TODO: Implement for each op
virtual
void
propagate_types
()
override
{}
protected
:
BuiltinCall
(
const
std
::
shared_ptr
<
Op
>&
op
,
const
std
::
vector
<
Node
::
ptr
>&
args
)
:
Call
(
op
,
args
)
...
...
@@ -89,12 +135,29 @@ namespace ngraph
}
};
namespace
op
class
AbsCall
:
public
BuiltinCall
{
std
::
shared_ptr
<
Node
>
broadcast
(
const
Node
::
ptr
&
tensor
,
const
Shape
&
shape
,
const
std
::
vector
<
size_t
>&
broadcast_axes
);
}
public
:
AbsCall
(
const
Node
::
ptr
&
arg0
)
:
BuiltinCall
(
s_op
,
{
arg0
})
{
}
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
AddCall
:
public
BuiltinCall
{
public
:
AddCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
BroadcastCall
:
public
BuiltinCall
{
...
...
@@ -111,17 +174,39 @@ namespace ngraph
,
m_broadcast_axes
(
broadcast_axes
)
{
}
virtual
void
propagate_types
()
override
;
protected
:
Shape
m_shape
;
std
::
vector
<
size_t
>
m_broadcast_axes
;
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
CeilingCall
:
public
BuiltinCall
{
public
:
CeilingCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
namespace
op
class
DivideCall
:
public
BuiltinCall
{
std
::
shared_ptr
<
Node
>
dot
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
);
}
public
:
DivideCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
DotCall
:
public
BuiltinCall
{
...
...
@@ -131,7 +216,179 @@ namespace ngraph
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
virtual
void
propagate_types
()
override
;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
EqualCall
:
public
BuiltinCall
{
public
:
EqualCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
ExponentialCall
:
public
BuiltinCall
{
public
:
ExponentialCall
(
const
Node
::
ptr
&
arg0
)
:
BuiltinCall
(
s_op
,
{
arg0
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
FloorCall
:
public
BuiltinCall
{
public
:
FloorCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
GreaterCall
:
public
BuiltinCall
{
public
:
GreaterCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
LessCall
:
public
BuiltinCall
{
public
:
LessCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
LogCall
:
public
BuiltinCall
{
public
:
LogCall
(
const
Node
::
ptr
&
arg0
)
:
BuiltinCall
(
s_op
,
{
arg0
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
MaximumCall
:
public
BuiltinCall
{
public
:
MaximumCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
MinimumCall
:
public
BuiltinCall
{
public
:
MinimumCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
MultiplyCall
:
public
BuiltinCall
{
public
:
MultiplyCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
NegateCall
:
public
BuiltinCall
{
public
:
NegateCall
(
const
Node
::
ptr
&
arg0
)
:
BuiltinCall
(
s_op
,
{
arg0
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
PowerCall
:
public
BuiltinCall
{
public
:
PowerCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
RemainderCall
:
public
BuiltinCall
{
public
:
RemainderCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
ReshapeCall
:
public
BuiltinCall
{
public
:
ReshapeCall
(
const
Node
::
ptr
&
arg0
,
const
Shape
&
shape
)
:
BuiltinCall
(
s_op
,
{
arg0
})
,
m_shape
(
shape
)
{
}
//virtual void propagate_types() override;
protected
:
Shape
m_shape
;
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
class
SubtractCall
:
public
BuiltinCall
{
public
:
SubtractCall
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
:
BuiltinCall
(
s_op
,
{
arg0
,
arg1
})
{
}
//virtual void propagate_types() override;
protected
:
static
std
::
shared_ptr
<
BuiltinOp
>
s_op
;
};
...
...
src/ngraph/shape.hpp
View file @
a136956b
...
...
@@ -32,11 +32,19 @@ namespace ngraph
{
}
Shape
(
const
std
::
vector
<
size_t
>&
sizes
)
:
m_sizes
(
sizes
)
{
}
/**
** Conversion to a vector of sizes.
**/
operator
const
std
::
vector
<
size_t
>&
()
const
{
return
m_sizes
;
}
bool
operator
==
(
const
Shape
&
shape
)
const
{
return
m_sizes
==
shape
.
m_sizes
;
}
bool
operator
!=
(
const
Shape
&
shape
)
const
{
return
m_sizes
!=
shape
.
m_sizes
;
}
protected
:
std
::
vector
<
size_t
>
m_sizes
;
};
...
...
src/ngraph/type.hpp
View file @
a136956b
...
...
@@ -22,6 +22,9 @@
namespace
ngraph
{
class
TensorViewType
;
class
TupleType
;
/**
** ValueType is
** TensorViewType
...
...
@@ -34,6 +37,10 @@ namespace ngraph
** Preferred handle
**/
using
ptr
=
std
::
shared_ptr
<
ValueType
>
;
virtual
~
ValueType
()
{}
virtual
std
::
shared_ptr
<
TensorViewType
>
as_tensor_view_type
()
{
return
nullptr
;
}
virtual
std
::
shared_ptr
<
TupleType
>
as_tuple_type
()
{
return
nullptr
;
}
};
/**
...
...
@@ -57,6 +64,9 @@ namespace ngraph
{
}
const
element
::
Type
&
element_type
()
const
{
return
m_element_type
;
}
const
Shape
shape
()
const
{
return
m_shape
;
}
protected
:
const
element
::
Type
&
m_element_type
;
Shape
m_shape
;
...
...
src/ops/function.cpp
View file @
a136956b
...
...
@@ -24,6 +24,14 @@ Parameter::Parameter(Function& function, size_t index)
{
}
void
Parameter
::
propagate_types
()
{
if
(
m_type
==
nullptr
)
{
throw
ngraph_error
{
"Unitialized parameter"
};
}
}
Function
::
Function
(
size_t
n_parameters
)
:
m_parameters
(
n_parameters
)
,
m_name
(
"Function"
)
...
...
src/ops/op.cpp
View file @
a136956b
...
...
@@ -12,11 +12,27 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "ngraph/ngraph.hpp"
using
namespace
ngraph
;
using
namespace
std
;
std
::
shared_ptr
<
BuiltinOp
>
AbsCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"abs"
);
Node
::
ptr
ngraph
::
op
::
abs
(
const
Node
::
ptr
&
arg
)
{
return
make_shared
<
AbsCall
>
(
arg
);
}
std
::
shared_ptr
<
BuiltinOp
>
AddCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"add"
);
Node
::
ptr
ngraph
::
op
::
add
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
AddCall
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
BuiltinOp
>
BroadcastCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"broadcast"
);
/**
...
...
@@ -25,17 +41,191 @@ std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadca
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
shared_ptr
<
Node
>
ngraph
::
op
::
broadcast
(
const
Node
::
ptr
&
tensor
,
const
Shape
&
shape
,
const
vector
<
size_t
>&
broadcast_axes
)
Node
::
ptr
ngraph
::
op
::
broadcast
(
const
Node
::
ptr
&
tensor
,
const
Shape
&
shape
,
const
vector
<
size_t
>&
broadcast_axes
)
{
return
make_shared
<
BroadcastCall
>
(
tensor
,
shape
,
broadcast_axes
);
}
void
BroadcastCall
::
propagate_types
()
{
auto
arg_type
=
m_arguments
.
at
(
0
)
->
type
();
if
(
nullptr
==
arg_type
)
{
throw
ngraph_error
(
"Argument to broadcast is missing type."
);
}
auto
arg_tensor_view_type
=
arg_type
->
as_tensor_view_type
();
if
(
nullptr
==
arg_tensor_view_type
)
{
throw
ngraph_error
(
"Argument to broadcast is not a tensor view"
);
}
vector
<
size_t
>
target_shape
=
m_shape
;
for
(
auto
i
=
m_broadcast_axes
.
rbegin
();
i
!=
m_broadcast_axes
.
rend
();
++
i
)
{
target_shape
.
erase
(
target_shape
.
begin
()
+
*
i
);
}
if
(
Shape
{
target_shape
}
!=
arg_tensor_view_type
->
shape
())
{
throw
ngraph_error
(
"Broadcast arg, shape, and axes are incompatible"
);
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type
=
make_shared
<
TensorViewType
>
(
arg_tensor_view_type
->
element_type
(),
m_shape
);
}
std
::
shared_ptr
<
BuiltinOp
>
CeilingCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"ceiling"
);
Node
::
ptr
ngraph
::
op
::
ceiling
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
CeilingCall
>
(
arg0
,
arg1
);
}
// 'concatenate',
// 'constant',
// 'convert',
// 'convolution',
std
::
shared_ptr
<
BuiltinOp
>
DivideCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"divide"
);
Node
::
ptr
ngraph
::
op
::
divide
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
DivideCall
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
BuiltinOp
>
DotCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"dot"
);
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
shared_ptr
<
Node
>
ngraph
::
op
::
dot
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
Node
::
ptr
ngraph
::
op
::
dot
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
DotCall
>
(
arg0
,
arg1
);
}
void
DotCall
::
propagate_types
()
{
auto
arg0_tensor_type
=
m_arguments
.
at
(
0
)
->
type
()
->
as_tensor_view_type
();
auto
arg1_tensor_type
=
m_arguments
.
at
(
1
)
->
type
()
->
as_tensor_view_type
();
if
(
nullptr
==
arg0_tensor_type
||
nullptr
==
arg1_tensor_type
)
{
throw
ngraph_error
(
"Arguments to dot must be tensor views"
);
}
if
(
arg0_tensor_type
->
element_type
()
!=
arg1_tensor_type
->
element_type
())
{
throw
ngraph_error
(
"Arguments to dot must have the same element type"
);
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector
<
size_t
>
arg0_shape
=
arg0_tensor_type
->
shape
();
vector
<
size_t
>
arg1_shape
=
arg1_tensor_type
->
shape
();
size_t
arg0_reduction
=
arg0_shape
.
size
()
-
1
;
size_t
arg1_reduction
;
if
(
arg1_shape
.
size
()
>
1
)
{
arg1_reduction
=
arg1_shape
.
size
()
-
2
;
}
else
{
arg1_reduction
=
arg1_shape
.
size
()
-
1
;
}
if
(
arg0_shape
.
at
(
arg0_reduction
)
!=
arg1_shape
.
at
(
arg1_reduction
))
{
throw
ngraph_error
(
"Dot reduction axes not compatible"
);
}
vector
<
size_t
>
result_shape
;
copy
(
arg0_shape
.
begin
(),
arg0_shape
.
begin
()
+
arg1_reduction
,
result_shape
.
end
());
copy
(
arg1_shape
.
begin
(),
arg1_shape
.
begin
()
+
arg1_reduction
,
result_shape
.
end
());
copy
(
arg1_shape
.
begin
()
+
arg1_reduction
,
arg1_shape
.
end
(),
result_shape
.
end
());
m_type
=
make_shared
<
TensorViewType
>
(
arg0_tensor_type
->
element_type
(),
result_shape
);
}
std
::
shared_ptr
<
BuiltinOp
>
ExponentialCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"exponential"
);
Node
::
ptr
ngraph
::
op
::
exponential
(
const
Node
::
ptr
&
arg0
)
{
return
make_shared
<
ExponentialCall
>
(
arg0
);
}
std
::
shared_ptr
<
BuiltinOp
>
FloorCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"floor"
);
Node
::
ptr
ngraph
::
op
::
floor
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
FloorCall
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
BuiltinOp
>
LogCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"log"
);
Node
::
ptr
ngraph
::
op
::
log
(
const
Node
::
ptr
&
arg0
)
{
return
make_shared
<
LogCall
>
(
arg0
);
}
std
::
shared_ptr
<
BuiltinOp
>
MaximumCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"maximum"
);
Node
::
ptr
ngraph
::
op
::
maximum
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
MaximumCall
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
BuiltinOp
>
MinimumCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"minimum"
);
Node
::
ptr
ngraph
::
op
::
minimum
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
MinimumCall
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
BuiltinOp
>
MultiplyCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"multiply"
);
Node
::
ptr
ngraph
::
op
::
multiply
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
MultiplyCall
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
BuiltinOp
>
NegateCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"negate"
);
Node
::
ptr
ngraph
::
op
::
negate
(
const
Node
::
ptr
&
arg0
)
{
return
make_shared
<
NegateCall
>
(
arg0
);
}
// 'pad',
// 'parameter',
std
::
shared_ptr
<
BuiltinOp
>
PowerCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"power"
);
Node
::
ptr
ngraph
::
op
::
power
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
PowerCall
>
(
arg0
,
arg1
);
}
//'reduce',
std
::
shared_ptr
<
BuiltinOp
>
RemainderCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"remainder"
);
Node
::
ptr
ngraph
::
op
::
remainder
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
RemainderCall
>
(
arg0
,
arg1
);
}
std
::
shared_ptr
<
BuiltinOp
>
ReshapeCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"reshape"
);
Node
::
ptr
ngraph
::
op
::
reshape
(
const
Node
::
ptr
&
arg0
,
const
Shape
&
shape
)
{
return
make_shared
<
ReshapeCall
>
(
arg0
,
shape
);
}
//'reverse',
//'rng',
// 'select',
//'slice',
std
::
shared_ptr
<
BuiltinOp
>
SubtractCall
::
s_op
=
make_shared
<
BuiltinOp
>
(
"subtract"
);
Node
::
ptr
ngraph
::
op
::
subtract
(
const
Node
::
ptr
&
arg0
,
const
Node
::
ptr
&
arg1
)
{
return
make_shared
<
SubtractCall
>
(
arg0
,
arg1
);
}
// 'transpose',
//'tuple',
// 'while'
test/build_graph.cpp
View file @
a136956b
...
...
@@ -19,7 +19,7 @@
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
DISABLED_
graph
,
build_simple
)
TEST
(
n
graph
,
build_simple
)
{
// Function with 4 parameters
auto
cluster_0
=
make_shared
<
Function
>
(
4
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment