Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
f7bfd75e
Commit
f7bfd75e
authored
Sep 14, 2017
by
Bob Kimball
Browse files
Options
Browse Files
Download
Plain Diff
merge master
parents
bf022034
70f8c112
Hide whitespace changes
Inline
Side-by-side
Showing
22 changed files
with
626 additions
and
56 deletions
+626
-56
CMakeLists.txt
src/CMakeLists.txt
+6
-0
common.hpp
src/ngraph/common.hpp
+7
-0
buffer.hpp
src/ngraph/descriptor/buffer.hpp
+30
-0
call_frame.hpp
src/ngraph/descriptor/call_frame.hpp
+41
-0
tensor_view_layout.hpp
src/ngraph/descriptor/tensor_view_layout.hpp
+12
-3
ngraph.hpp
src/ngraph/ngraph.hpp
+6
-0
constant.hpp
src/ngraph/ops/constant.hpp
+55
-1
call_frame.cpp
src/ngraph/runtime/call_frame.cpp
+29
-0
call_frame.hpp
src/ngraph/runtime/call_frame.hpp
+59
-0
tensor_view.cpp
src/ngraph/runtime/eigen/tensor_view.cpp
+29
-0
tensor_view.hpp
src/ngraph/runtime/eigen/tensor_view.hpp
+100
-0
function.hpp
src/ngraph/runtime/function.hpp
+37
-0
tensor_view.hpp
src/ngraph/runtime/tensor_view.hpp
+28
-0
shape.cpp
src/ngraph/shape.cpp
+27
-5
shape.hpp
src/ngraph/shape.hpp
+6
-25
constant.cpp
src/ops/constant.cpp
+3
-0
element_type.cpp
src/types/element_type.cpp
+1
-2
type.cpp
src/types/type.cpp
+2
-1
CMakeLists.txt
test/CMakeLists.txt
+2
-0
build_graph.cpp
test/build_graph.cpp
+63
-19
runtime.cpp
test/runtime.cpp
+46
-0
shape.cpp
test/shape.cpp
+37
-0
No files found.
src/CMakeLists.txt
View file @
f7bfd75e
...
...
@@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
include_directories
(
SYSTEM
${
EIGEN_INCLUDE_DIR
}
)
set
(
SRC
log.cpp
ngraph/descriptor/input.cpp
...
...
@@ -28,6 +30,9 @@ set (SRC
ngraph/pass/propagate_types.cpp
ngraph/pass/topological_sort.cpp
ngraph/pass/tree_pass.cpp
ngraph/runtime/call_frame.cpp
ngraph/runtime/eigen/tensor_view.cpp
ngraph/shape.cpp
ngraph/visualize.cpp
ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp
...
...
@@ -92,3 +97,4 @@ install(DIRECTORY
FILES_MATCHING PATTERN
"*.hpp"
)
add_dependencies
(
ngraph eigen
)
src/ngraph/common.hpp
View file @
f7bfd75e
...
...
@@ -41,4 +41,11 @@ namespace ngraph
/// A set of axes, for example, reduction axes
using
AxisSet
=
std
::
set
<
size_t
>
;
/// Shape for a tensor
using
Shape
=
std
::
vector
<
size_t
>
;
/// Strides of a tensor
using
Strides
=
std
::
vector
<
size_t
>
;
}
src/ngraph/descriptor/buffer.hpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace
ngraph
{
namespace
descriptor
{
// A buffer identfies a chunk of storage
// In descriptors, we are identifying what will be associated with actual memory
// during execution.
class
Buffer
{
protected
:
size_t
size
;
};
}
}
src/ngraph/descriptor/call_frame.hpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/function.hpp"
namespace
ngraph
{
namespace
descriptor
{
// Describes the frame that will be used when a function is executing
class
CallFrame
{
protected
:
Function
m_function
;
// Will be provided by the caller
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>
m_inputs
;
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>
m_outputs
;
// Will be provided by the call mechanism
// Expect there to be only one buffer
std
::
vector
<
std
::
shared_ptr
<
Buffer
>>
m_buffers
;
};
}
}
src/ngraph/descriptor/tensor_view_layout.hpp
View file @
f7bfd75e
...
...
@@ -20,12 +20,21 @@ namespace ngraph
{
namespace
descriptor
{
using
Strides
=
std
::
vector
<
size_t
>
;
// An interface for describing implementations of tensor views
// Kernel selection will need to pay attention to the layout
class
TensorViewLayout
{
public
:
virtual
~
TensorViewLayout
()
{}
};
// The standard strided layout
class
DenseTensorViewLayout
:
public
TensorViewLayout
{
protected
:
Strides
m_strides
;
std
::
shared_ptr
<
Buffer
>
m_buffer
;
Strides
m_strides
;
size_t
m_offset
;
};
}
}
src/ngraph/ngraph.hpp
View file @
f7bfd75e
...
...
@@ -19,6 +19,8 @@
#pragma once
#include "ngraph/common.hpp"
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/call_frame.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
...
...
@@ -42,5 +44,9 @@
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/function.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
src/ngraph/ops/constant.hpp
View file @
f7bfd75e
...
...
@@ -17,6 +17,7 @@
#include <sstream>
#include "ngraph/element_type.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
namespace
ngraph
{
...
...
@@ -59,7 +60,10 @@ namespace ngraph
return
ss
.
str
();
}
typename
T
::
type
get_value
()
const
{
return
m_value
;
}
type
get_value
()
const
{
return
m_value
;
}
protected
:
typename
T
::
type
m_value
;
...
...
@@ -72,5 +76,55 @@ namespace ngraph
using
UInt8ScalarConstant
=
ScalarConstant
<
element
::
UInt8
>
;
using
UInt32ScalarConstant
=
ScalarConstant
<
element
::
UInt32
>
;
using
UInt64ScalarConstant
=
ScalarConstant
<
element
::
UInt64
>
;
// Defines methods to all constant tensors
class
TensorConstantBase
:
public
Node
{
protected
:
TensorConstantBase
(
const
std
::
shared_ptr
<
TensorViewType
>&
type
)
:
Node
({},
type
)
{
}
virtual
void
propagate_types
()
override
;
};
// Implement a constant tensor for each element type.
template
<
typename
T
>
class
TensorConstant
:
public
TensorConstantBase
{
public
:
// The ngraph element type
using
element_type
=
T
;
// The C++ type that holds the element type
using
type
=
typename
T
::
type
;
TensorConstant
(
const
Shape
&
shape
)
:
TensorConstantBase
(
std
::
make_shared
<
TensorViewType
>
(
T
::
element_type
(),
shape
))
,
m_value
(
std
::
make_shared
<
ngraph
::
runtime
::
eigen
::
PrimaryTensorView
<
T
>>
(
shape
))
{
}
virtual
std
::
string
description
()
const
override
{
return
"TensorConstant"
;
}
virtual
std
::
string
get_node_id
()
const
override
{
std
::
stringstream
ss
;
ss
<<
description
()
<<
"_"
/* << node_id() */
;
return
ss
.
str
();
}
typename
std
::
shared_ptr
<
ngraph
::
runtime
::
eigen
::
PrimaryTensorView
<
T
>>
get_value
()
const
{
return
m_value
;
}
protected
:
std
::
shared_ptr
<
ngraph
::
runtime
::
eigen
::
PrimaryTensorView
<
T
>>
m_value
;
};
using
Float32TensorConstant
=
TensorConstant
<
element
::
Float32
>
;
using
Int8TensorConstant
=
TensorConstant
<
element
::
Int8
>
;
using
Int32TensorConstant
=
TensorConstant
<
element
::
Int32
>
;
using
Int64TensorConstant
=
TensorConstant
<
element
::
Int64
>
;
using
UInt8TensorConstant
=
TensorConstant
<
element
::
UInt8
>
;
using
UInt32TensorConstant
=
TensorConstant
<
element
::
UInt32
>
;
using
UInt64TensorConstant
=
TensorConstant
<
element
::
UInt64
>
;
}
}
src/ngraph/runtime/call_frame.cpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
runtime
;
CallFrame
::
CallFrame
(
Function
&
function
,
const
std
::
vector
<
std
::
shared_ptr
<
PrimaryTensorView
>>&
arguments
,
const
std
::
vector
<
std
::
shared_ptr
<
PrimaryTensorView
>>&
results
)
{
m_tensors
.
insert
(
m_tensors
.
end
(),
arguments
.
begin
(),
arguments
.
end
());
m_tensors
.
insert
(
m_tensors
.
end
(),
results
.
begin
(),
results
.
end
());
// TBD
// From Function allocate tensors for the temporaries
}
src/ngraph/runtime/call_frame.hpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/runtime/function.hpp"
namespace
ngraph
{
namespace
runtime
{
class
CallFrameAccessor
;
// This is constructed when a runtime function is called.
class
CallFrame
{
friend
class
CallFrameAccessor
;
public
:
CallFrame
(
Function
&
function
,
const
std
::
vector
<
std
::
shared_ptr
<
PrimaryTensorView
>>&
arguments
,
const
std
::
vector
<
std
::
shared_ptr
<
PrimaryTensorView
>>&
results
);
protected
:
std
::
vector
<
std
::
shared_ptr
<
PrimaryTensorView
>>
m_tensors
;
};
class
CallFrameAccessor
{
public
:
CallFrameAccessor
(
size_t
index
)
:
m_index
(
index
)
{
}
std
::
shared_ptr
<
PrimaryTensorView
>
operator
()(
CallFrame
&
call_frame
)
{
return
call_frame
.
m_tensors
[
m_index
];
}
protected
:
size_t
m_index
;
};
}
}
src/ngraph/runtime/eigen/tensor_view.cpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <Eigen/Dense>
#include "ngraph.hpp"
using
namespace
Eigen
;
using
namespace
ngraph
::
runtime
::
eigen
;
using
namespace
ngraph
::
element
;
template
void
ngraph
::
runtime
::
eigen
::
add
<
Float32
>
(
const
PrimaryTensorView
<
Float32
>&
arg0
,
const
PrimaryTensorView
<
Float32
>&
arg1
,
PrimaryTensorView
<
Float32
>&
out
);
template
void
ngraph
::
runtime
::
eigen
::
multiply
<
Float32
>
(
const
PrimaryTensorView
<
Float32
>&
arg0
,
const
PrimaryTensorView
<
Float32
>&
arg1
,
PrimaryTensorView
<
Float32
>&
out
);
src/ngraph/runtime/eigen/tensor_view.hpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <Eigen/Dense>
#include <vector>
#include "ngraph/shape.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
eigen
{
template
<
typename
ET
>
class
PrimaryTensorView
:
public
ngraph
::
runtime
::
PrimaryTensorView
{
public
:
// Standard definitions from vector
using
value_type
=
typename
ET
::
type
;
using
storage_type
=
std
::
vector
<
value_type
>
;
using
size_type
=
typename
storage_type
::
size_type
;
using
difference_type
=
typename
storage_type
::
difference_type
;
using
reference
=
typename
storage_type
::
reference
;
using
const_reference
=
typename
storage_type
::
const_reference
;
using
pointer
=
typename
storage_type
::
pointer
;
using
const_pointer
=
typename
storage_type
::
const_pointer
;
using
iterator
=
typename
storage_type
::
iterator
;
using
const_iterator
=
typename
storage_type
::
const_iterator
;
using
reverse_iterator
=
typename
storage_type
::
reverse_iterator
;
using
const_reverse_iterator
=
typename
storage_type
::
const_reverse_iterator
;
// Mapping vector to eigen
using
eigen_type
=
Eigen
::
Array
<
value_type
,
Eigen
::
Dynamic
,
1
>
;
using
eigen_map
=
Eigen
::
Map
<
eigen_type
>
;
PrimaryTensorView
(
const
ngraph
::
Shape
&
shape
)
:
m_shape
(
shape
)
,
m_size
(
ngraph
::
shape_size
(
shape
))
,
m_strides
(
ngraph
::
row_major_strides
(
m_shape
))
,
m_vector
(
m_size
,
0
)
,
m_map
(
&
m_vector
[
0
],
m_size
,
1
)
{
}
template
<
typename
T
>
PrimaryTensorView
&
operator
=
(
const
T
&
value
)
{
m_vector
=
value
;
return
*
this
;
}
// For getting the data out
const
storage_type
&
get_vector
()
{
return
m_vector
;
}
eigen_map
&
get_map
()
{
return
m_map
;
}
const
eigen_map
&
get_map
()
const
{
return
m_map
;
}
const
Shape
&
get_shape
()
const
{
return
m_shape
;
}
protected
:
ngraph
::
Shape
m_shape
;
size_t
m_size
;
ngraph
::
Strides
m_strides
;
storage_type
m_vector
;
eigen_map
m_map
;
};
template
<
typename
ET
>
void
add
(
const
PrimaryTensorView
<
ET
>&
arg0
,
const
PrimaryTensorView
<
ET
>&
arg1
,
PrimaryTensorView
<
ET
>&
out
)
{
out
.
get_map
()
=
arg0
.
get_map
()
+
arg1
.
get_map
();
}
template
<
typename
ET
>
void
multiply
(
const
PrimaryTensorView
<
ET
>&
arg0
,
const
PrimaryTensorView
<
ET
>&
arg1
,
PrimaryTensorView
<
ET
>&
out
)
{
out
.
get_map
()
=
arg0
.
get_map
()
*
arg1
.
get_map
();
}
}
}
}
src/ngraph/runtime/function.hpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/runtime/tensor_view.hpp"
namespace
ngraph
{
namespace
runtime
{
// A compiled graph function
class
Function
{
public
:
virtual
~
Function
()
{}
// Invoke the function with a the given inputs and outputs
void
operator
()(
std
::
vector
<
std
::
shared_ptr
<
PrimaryTensorView
>>
inputs
,
std
::
vector
<
std
::
shared_ptr
<
PrimaryTensorView
>>
outputs
);
};
}
}
src/ngraph/runtime/tensor_view.hpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace
ngraph
{
namespace
runtime
{
// Actual tensor views are parameterized on element type
class
PrimaryTensorView
{
public
:
virtual
~
PrimaryTensorView
(){}
};
}
}
src/ngraph/shape.cpp
View file @
f7bfd75e
...
...
@@ -12,11 +12,33 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include
"shape.hpp"
#include
"util.hpp"
#include
<algorithm>
#include
<vector>
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
ngraph
::
Shape
&
obj
)
#include "ngraph/shape.hpp"
using
namespace
std
;
using
namespace
ngraph
;
size_t
ngraph
::
shape_size
(
const
Shape
&
shape
)
{
size_t
size
=
1
;
for
(
auto
d
:
shape
)
{
size
*=
d
;
}
return
size
;
}
Strides
ngraph
::
row_major_strides
(
const
Shape
&
shape
)
{
out
<<
"{"
<<
join
(
obj
.
m_sizes
,
", "
)
<<
"}"
;
return
out
;
Strides
strides
;
size_t
s
=
1
;
for
(
auto
d
=
shape
.
rbegin
();
d
!=
shape
.
rend
();
d
++
)
{
strides
.
push_back
(
s
);
s
*=
*
d
;
}
reverse
(
strides
.
begin
(),
strides
.
end
());
return
strides
;
}
src/ngraph/shape.hpp
View file @
f7bfd75e
...
...
@@ -18,32 +18,13 @@
#include <iostream>
#include <vector>
#include "common.hpp"
namespace
ngraph
{
/**
** Holds the shape of a tensor view.
**/
class
Shape
{
public
:
/// @param sizes A sequence of sizes.
Shape
(
const
std
::
initializer_list
<
size_t
>&
sizes
)
:
m_sizes
(
sizes
)
{
}
Shape
(
const
std
::
vector
<
size_t
>&
sizes
)
:
m_sizes
(
sizes
)
{
}
/// Conversion to a vector of sizes.
operator
const
std
::
vector
<
size_t
>&
()
const
{
return
m_sizes
;
}
bool
operator
==
(
const
Shape
&
shape
)
const
{
return
m_sizes
==
shape
.
m_sizes
;
}
bool
operator
!=
(
const
Shape
&
shape
)
const
{
return
m_sizes
!=
shape
.
m_sizes
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Shape
&
);
/// Number of elements in spanned by a shape
size_t
shape_size
(
const
Shape
&
shape
);
protected
:
std
::
vector
<
size_t
>
m_sizes
;
};
/// Row-major strides for a shape
Strides
row_major_strides
(
const
Shape
&
shape
);
}
src/ops/constant.cpp
View file @
f7bfd75e
...
...
@@ -17,3 +17,6 @@
using
namespace
ngraph
::
op
;
void
ScalarConstantBase
::
propagate_types
()
{}
void
TensorConstantBase
::
propagate_types
()
{}
src/types/element_type.cpp
View file @
f7bfd75e
...
...
@@ -55,4 +55,4 @@ std::ostream& ngraph::element::operator<<(std::ostream& out, const ngraph::eleme
{
out
<<
obj
.
m_cname
;
return
out
;
}
\ No newline at end of file
}
src/types/type.cpp
View file @
f7bfd75e
...
...
@@ -16,6 +16,7 @@
#include "ngraph/ngraph.hpp"
#include "log.hpp"
#include "util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
...
...
@@ -69,7 +70,7 @@ std::ostream& ngraph::operator<<(std::ostream& out, const ValueType& obj)
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
TensorViewType
&
obj
)
{
out
<<
"TensorViewType("
<<
obj
.
m_element_type
<<
",
"
<<
obj
.
m_shape
<<
"
)"
;
out
<<
"TensorViewType("
<<
obj
.
m_element_type
<<
",
{"
<<
join
(
obj
.
m_shape
)
<<
"}
)"
;
return
out
;
}
...
...
test/CMakeLists.txt
View file @
f7bfd75e
...
...
@@ -31,6 +31,8 @@ set (SRC
pass_liveness.cpp
pass_manager.cpp
pass_memory_layout.cpp
runtime.cpp
shape.cpp
tensor.cpp
test_tools.cpp
topological_sort.cpp
...
...
test/build_graph.cpp
View file @
f7bfd75e
...
...
@@ -98,40 +98,84 @@ TEST(build_graph, literal)
ASSERT_NE
(
*
int32_0
->
get_value_type
(),
*
float_scalar_type
);
}
TEST
(
build_graph
,
tensor
)
{
// float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0);
auto
float0
=
make_shared
<
op
::
Float32TensorConstant
>
(
Shape
{
2
,
3
});
auto
float_tensor_type
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
3
});
ASSERT_EQ
(
*
float0
->
get_value_type
(),
*
float_tensor_type
);
auto
d
=
make_shared
<
op
::
Dot
>
(
float0
,
float0
);
ASSERT_EQ
(
d
->
get_arguments
().
at
(
0
),
float0
);
ASSERT_EQ
(
d
->
get_arguments
().
at
(
1
),
float0
);
auto
int32_0
=
make_shared
<
op
::
Int32TensorConstant
>
(
Shape
{
3
,
5
});
auto
int32_tensor_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{
3
,
5
});
ASSERT_EQ
(
*
int32_0
->
get_value_type
(),
*
int32_tensor_type
);
ASSERT_NE
(
*
int32_0
->
get_value_type
(),
*
float_tensor_type
);
}
TEST
(
build_graph
,
set_value_type_checked
)
{
auto
untyped_param
=
make_shared
<
op
::
Parameter
>
();
try
{
untyped_param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
4
,
4
}));
}
catch
(...){
try
{
untyped_param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
4
,
4
}));
}
catch
(...)
{
FAIL
()
<<
"Setting value type for first time type failed."
;
}
try
{
untyped_param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
4
,
4
}));
}
catch
(...){
try
{
untyped_param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
4
,
4
}));
}
catch
(...)
{
FAIL
()
<<
"Setting value type to same type failed."
;
}
try
{
untyped_param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
4
,
5
}));
try
{
untyped_param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
4
,
5
}));
FAIL
()
<<
"Setting value type to a different shape did not fail."
;
}
catch
(
const
ngraph_error
&
error
){
}
catch
(
const
ngraph_error
&
error
)
{
EXPECT_EQ
(
error
.
what
(),
std
::
string
(
"Setting value type to a different ValueType"
));
}
catch
(...){
}
catch
(...)
{
FAIL
()
<<
"Setting value type to a different shape did not failed with incorrect error."
;
}
try
{
untyped_param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{
4
,
4
}));
try
{
untyped_param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{
4
,
4
}));
FAIL
()
<<
"Setting value type to a different element type did not fail."
;
}
catch
(
const
ngraph_error
&
error
){
}
catch
(
const
ngraph_error
&
error
)
{
EXPECT_EQ
(
error
.
what
(),
std
::
string
(
"Setting value type to a different ValueType"
));
}
catch
(...){
FAIL
()
<<
"Setting value type to a different element type did not failed with incorrect error."
;
}
catch
(...)
{
FAIL
()
<<
"Setting value type to a different element type did not failed with incorrect "
"error."
;
}
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
4
,
4
});
try
{
param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
4
,
4
}));
}
catch
(...){
try
{
param
->
set_value_type_checked
(
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
Shape
{
4
,
4
}));
}
catch
(...)
{
FAIL
()
<<
"Setting value type to same type failed."
;
}
}
...
...
test/runtime.cpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
runtime
::
eigen
;
TEST
(
runtime
,
test_add
)
{
auto
x
=
make_shared
<
PrimaryTensorView
<
element
::
Float32
>>
(
Shape
{
2
,
2
});
*
x
=
std
::
vector
<
float
>
{
1
,
2
,
3
,
4
};
auto
y
=
make_shared
<
PrimaryTensorView
<
element
::
Float32
>>
(
Shape
{
2
,
2
});
*
y
=
std
::
vector
<
float
>
{
5
,
6
,
7
,
8
};
auto
z
=
make_shared
<
PrimaryTensorView
<
element
::
Float32
>>
(
Shape
{
2
,
2
});
add
(
*
x
,
*
y
,
*
z
);
ASSERT_EQ
((
vector
<
float
>
{
6
,
8
,
10
,
12
}),
z
->
get_vector
());
}
TEST
(
runtime
,
test_multiply
)
{
auto
x
=
make_shared
<
op
::
Float32TensorConstant
>
(
Shape
{
2
,
2
});
*
x
->
get_value
()
=
std
::
vector
<
float
>
{
1
,
2
,
3
,
4
};
auto
y
=
make_shared
<
op
::
Float32TensorConstant
>
(
Shape
{
2
,
2
});
*
y
->
get_value
()
=
std
::
vector
<
float
>
{
5
,
6
,
7
,
8
};
auto
z
=
make_shared
<
op
::
Float32TensorConstant
>
(
Shape
{
2
,
2
});
multiply
(
*
x
->
get_value
(),
*
y
->
get_value
(),
*
z
->
get_value
());
ASSERT_EQ
((
vector
<
float
>
{
5
,
12
,
21
,
32
}),
z
->
get_value
()
->
get_vector
());
}
test/shape.cpp
0 → 100644
View file @
f7bfd75e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
runtime
::
eigen
;
TEST
(
shape
,
test_shape_size
)
{
ASSERT_EQ
(
1
,
shape_size
(
Shape
{}));
ASSERT_EQ
(
2
*
3
*
5
,
shape_size
(
Shape
{
2
,
3
,
5
}));
}
TEST
(
shape
,
test_shape_strides
)
{
ASSERT_EQ
(
Strides
{},
row_major_strides
(
Shape
{}));
ASSERT_EQ
(
Strides
{
1
},
row_major_strides
(
Shape
{
3
}));
ASSERT_EQ
((
Strides
{
7
,
1
}),
row_major_strides
(
Shape
{
2
,
7
}));
ASSERT_EQ
((
Strides
{
84
,
12
,
1
}),
row_major_strides
(
Shape
{
5
,
7
,
12
}));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment