Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
f756fa0d
Commit
f756fa0d
authored
7 years ago
by
Scott Cyphers
Committed by
GitHub
7 years ago
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into jmenon/fix
parents
7b9461ae
64412792
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
104 additions
and
7 deletions
+104
-7
tensor_view_layout.hpp
src/ngraph/descriptor/layout/tensor_view_layout.hpp
+5
-0
convert.cpp
src/ngraph/ops/convert.cpp
+3
-2
convert.hpp
src/ngraph/ops/convert.hpp
+2
-2
convert.hpp
src/ngraph/runtime/eigen/convert.hpp
+51
-0
external_function.cpp
src/ngraph/runtime/external_function.cpp
+0
-0
execute.cpp
test/execute.cpp
+0
-0
type_prop.cpp
test/type_prop.cpp
+43
-3
No files found.
src/ngraph/descriptor/layout/tensor_view_layout.hpp
View file @
f756fa0d
...
...
@@ -51,6 +51,11 @@ namespace ngraph
/// With non-linear buffers, this will need to be something other than size_t.
virtual
size_t
get_index_offset
(
const
std
::
vector
<
size_t
>&
indices
)
=
0
;
const
element
::
Type
&
get_element_type
()
const
{
return
m_tensor_view
.
get_tensor_view_type
()
->
get_element_type
();
}
const
Shape
&
get_shape
()
const
{
return
m_tensor_view
.
get_tensor_view_type
()
->
get_shape
();
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/ops/convert.cpp
View file @
f756fa0d
...
...
@@ -17,9 +17,10 @@
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
op
;
void
Convert
::
propagate_types
()
const
element
::
Type
&
Convert
::
propagate_element_types
(
const
element
::
Type
&
arg_element_type
)
const
{
throw
ngraph_error
(
"NIY"
)
;
return
m_element_type
;
}
This diff is collapsed.
Click to expand it.
src/ngraph/ops/convert.hpp
View file @
f756fa0d
...
...
@@ -27,9 +27,9 @@ namespace ngraph
{
}
virtual
const
element
::
Type
&
propagate_element_types
(
const
element
::
Type
&
arg_element_type
)
const
override
;
virtual
std
::
string
description
()
const
override
{
return
"Convert"
;
}
virtual
void
propagate_types
()
override
;
protected
:
const
ngraph
::
element
::
Type
&
m_element_type
;
};
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/eigen/convert.hpp
0 → 100644
View file @
f756fa0d
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
eigen
{
template
<
typename
ETI
,
typename
ETO
>
class
ConvertInstruction
:
public
Instruction
{
public
:
ConvertInstruction
(
const
TensorViewInfo
&
arg
,
const
TensorViewInfo
&
out
)
:
m_arg
(
arg
)
,
m_out
(
out
)
{
}
virtual
void
execute
(
CallFrame
&
call_frame
)
const
override
{
EigenArray1d
<
ETO
>
(
call_frame
,
m_out
)
=
EigenArray1d
<
ETI
>
(
call_frame
,
m_arg
).
template
cast
<
typename
ETO
::
type
>
();
}
protected
:
TensorViewInfo
m_arg
;
TensorViewInfo
m_out
;
};
}
}
}
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/external_function.cpp
View file @
f756fa0d
This diff is collapsed.
Click to expand it.
test/execute.cpp
View file @
f756fa0d
This diff is collapsed.
Click to expand it.
test/type_prop.cpp
View file @
f756fa0d
...
...
@@ -237,9 +237,49 @@ TEST(type_prop, concat_deduce_elem_type_mismatch)
}
}
//
// Tests for dot product.
//
TEST
(
type_prop
,
convert_deduce
)
{
// Deduce type
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
3
,
4
});
auto
c
=
make_shared
<
op
::
Convert
>
(
param
,
element
::
Int32
::
element_type
());
c
->
propagate_types
();
auto
c_vt
=
c
->
get_value_type
();
ASSERT_EQ
(
*
c_vt
,
TensorViewType
(
element
::
Int32
::
element_type
(),
Shape
{
2
,
3
,
4
}));
}
TEST
(
type_prop
,
convert_deduce_correct
)
{
// Check deduced type against incorrectly specified type
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
3
,
4
});
auto
c
=
make_shared
<
op
::
Convert
>
(
param
,
element
::
Int32
::
element_type
());
c
->
set_value_type
(
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{
2
,
3
,
4
}));
c
->
propagate_types
();
auto
c_vt
=
c
->
get_value_type
();
ASSERT_EQ
(
*
c_vt
,
TensorViewType
(
element
::
Int32
::
element_type
(),
Shape
{
2
,
3
,
4
}));
}
TEST
(
type_prop
,
convert_deduce_incorrect
)
{
// Check deduced type against incorrectly specified type
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
Shape
{
2
,
3
,
4
});
auto
c
=
make_shared
<
op
::
Convert
>
(
param
,
element
::
Int32
::
element_type
());
c
->
set_value_type
(
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{
2
,
14
,
4
}));
try
{
c
->
propagate_types
();
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Deduced type should disagree with specified type"
;
}
catch
(
const
ngraph_error
&
error
)
{
EXPECT_EQ
(
error
.
what
(),
std
::
string
(
"Setting value type to a different ValueType"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
dot_deduce_scalar_2d
)
{
// Deduce type for scalar/matrix arguments
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment