Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
7ed0fe7d
Commit
7ed0fe7d
authored
Sep 01, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Switch to get/set
parent
c66a7469
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
76 additions
and
76 deletions
+76
-76
node.cpp
src/ngraph/node.cpp
+3
-3
node.hpp
src/ngraph/node.hpp
+5
-5
op.hpp
src/ngraph/op.hpp
+20
-20
broadcast.hpp
src/ngraph/ops/broadcast.hpp
+1
-1
concatenate.hpp
src/ngraph/ops/concatenate.hpp
+1
-1
constant.hpp
src/ngraph/ops/constant.hpp
+2
-2
convert.hpp
src/ngraph/ops/convert.hpp
+1
-1
dot.hpp
src/ngraph/ops/dot.hpp
+1
-1
parameter.hpp
src/ngraph/ops/parameter.hpp
+1
-1
tuple.hpp
src/ngraph/ops/tuple.hpp
+1
-1
type.hpp
src/ngraph/type.hpp
+8
-8
visualize.cpp
src/ngraph/visualize.cpp
+2
-2
broadcast.cpp
src/ops/broadcast.cpp
+3
-3
dot.cpp
src/ops/dot.cpp
+6
-6
op.cpp
src/ops/op.cpp
+2
-2
parameter.cpp
src/ops/parameter.cpp
+1
-1
type.cpp
src/types/type.cpp
+3
-3
util.cpp
src/util.cpp
+3
-3
build_graph.cpp
test/build_graph.cpp
+11
-11
topological_sort.cpp
test/topological_sort.cpp
+1
-1
No files found.
src/ngraph/node.cpp
View file @
7ed0fe7d
...
...
@@ -45,15 +45,15 @@ std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Node& node)
auto
parameter_tmp
=
dynamic_cast
<
const
ngraph
::
Op
*>
(
&
node
);
if
(
op_tmp
)
{
out
<<
"Op("
<<
op_tmp
->
node_id
()
<<
")"
;
out
<<
"Op("
<<
op_tmp
->
get_
node_id
()
<<
")"
;
}
else
if
(
parameter_tmp
)
{
out
<<
"Parameter("
<<
parameter_tmp
->
node_id
()
<<
")"
;
out
<<
"Parameter("
<<
parameter_tmp
->
get_
node_id
()
<<
")"
;
}
else
{
out
<<
"Node("
<<
node
.
node_id
()
<<
")"
;
out
<<
"Node("
<<
node
.
get_
node_id
()
<<
")"
;
}
return
out
;
}
src/ngraph/node.hpp
View file @
7ed0fe7d
...
...
@@ -48,14 +48,14 @@ namespace ngraph
/// Propagate types and check arguments for consistency
virtual
void
propagate_types
()
=
0
;
const
Nodes
&
arguments
()
const
{
return
m_arguments
;
}
const
Nodes
&
get_
arguments
()
const
{
return
m_arguments
;
}
const
std
::
multiset
<
Node
*>&
users
()
const
{
return
m_users
;
}
std
::
string
name
()
const
{
return
m_name
;
}
void
name
(
const
std
::
string
&
name
)
{
m_name
=
name
;
}
std
::
string
get_
name
()
const
{
return
m_name
;
}
void
set_
name
(
const
std
::
string
&
name
)
{
m_name
=
name
;
}
virtual
std
::
string
node_id
()
const
=
0
;
virtual
std
::
string
get_
node_id
()
const
=
0
;
/**
** Return true if this has the same implementing class as node. This
...
...
@@ -70,7 +70,7 @@ namespace ngraph
bool
is_op
()
const
;
bool
is_parameter
()
const
;
size_t
instance_id
()
const
{
return
m_instance_id
;
}
size_t
get_
instance_id
()
const
{
return
m_instance_id
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
protected
:
...
...
src/ngraph/op.hpp
View file @
7ed0fe7d
...
...
@@ -74,8 +74,8 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
=
0
;
virtual
std
::
string
node_id
()
const
override
;
virtual
std
::
string
get_
op_class_name
()
const
=
0
;
virtual
std
::
string
get_
node_id
()
const
override
;
};
/**
...
...
@@ -116,7 +116,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"abs"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"abs"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -127,7 +127,7 @@ namespace ngraph
:
BuiltinOp
({
arg0
,
arg1
})
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"add"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"add"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -139,7 +139,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"ceiling"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"ceiling"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -151,7 +151,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"divide"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"divide"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -163,7 +163,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"equal"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"equal"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -175,7 +175,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"exp"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"exp"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -187,7 +187,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"floor"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"floor"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -199,7 +199,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"greater"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"greater"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -211,7 +211,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"less"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"less"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -223,7 +223,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"log"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"log"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -235,7 +235,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"max"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"max"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -247,7 +247,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"min"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"min"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -259,7 +259,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"multiply"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"multiply"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -271,7 +271,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"negative"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"negative"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -283,7 +283,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"power"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"power"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -295,7 +295,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"remainder"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"remainder"
;
}
//virtual void propagate_types() override;
};
...
...
@@ -308,7 +308,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"reshape"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"reshape"
;
}
//virtual void propagate_types() override;
protected
:
Shape
m_shape
;
...
...
@@ -322,7 +322,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"subtract"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"subtract"
;
}
//virtual void propagate_types() override;
};
}
src/ngraph/ops/broadcast.hpp
View file @
7ed0fe7d
...
...
@@ -32,7 +32,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"broadcast"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"broadcast"
;
}
virtual
void
propagate_types
()
override
;
protected
:
...
...
src/ngraph/ops/concatenate.hpp
View file @
7ed0fe7d
...
...
@@ -29,7 +29,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"concatenate"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"concatenate"
;
}
virtual
void
propagate_types
()
override
;
};
}
src/ngraph/ops/constant.hpp
View file @
7ed0fe7d
...
...
@@ -50,14 +50,14 @@ namespace ngraph
}
virtual
std
::
string
description
()
const
override
{
return
"ScalarConstant"
;
}
virtual
std
::
string
node_id
()
const
override
virtual
std
::
string
get_
node_id
()
const
override
{
std
::
stringstream
ss
;
ss
<<
description
()
<<
"_"
/* << node_id() */
;
return
ss
.
str
();
}
typename
T
::
type
value
()
const
{
return
m_value
;
}
typename
T
::
type
get_
value
()
const
{
return
m_value
;
}
protected
:
typename
T
::
type
m_value
;
...
...
src/ngraph/ops/convert.hpp
View file @
7ed0fe7d
...
...
@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"convert"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"convert"
;
}
virtual
void
propagate_types
()
override
;
protected
:
const
ngraph
::
element
::
Type
&
m_element_type
;
...
...
src/ngraph/ops/dot.hpp
View file @
7ed0fe7d
...
...
@@ -25,7 +25,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"dot"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"dot"
;
}
virtual
void
propagate_types
()
override
;
};
...
...
src/ngraph/ops/parameter.hpp
View file @
7ed0fe7d
...
...
@@ -41,7 +41,7 @@ namespace ngraph
std
::
string
description
()
const
override
{
return
"Parameter"
;
}
virtual
void
propagate_types
()
override
;
virtual
std
::
string
node_id
()
const
override
;
virtual
std
::
string
get_
node_id
()
const
override
;
protected
:
Function
*
m_function
;
...
...
src/ngraph/ops/tuple.hpp
View file @
7ed0fe7d
...
...
@@ -29,7 +29,7 @@ namespace ngraph
{
}
virtual
std
::
string
op_class_name
()
const
override
{
return
"tuple"
;
}
virtual
std
::
string
get_
op_class_name
()
const
override
{
return
"tuple"
;
}
virtual
void
propagate_types
()
override
;
};
}
src/ngraph/type.hpp
View file @
7ed0fe7d
...
...
@@ -64,8 +64,8 @@ namespace ngraph
{
}
const
element
::
Type
&
element_type
()
const
{
return
m_element_type
;
}
const
Shape
&
shape
()
const
{
return
m_shape
;
}
const
element
::
Type
&
get_
element_type
()
const
{
return
m_element_type
;
}
const
Shape
&
get_
shape
()
const
{
return
m_shape
;
}
virtual
bool
operator
==
(
const
ValueType
::
ptr
&
that
)
const
override
;
...
...
@@ -97,8 +97,8 @@ namespace ngraph
{
}
const
std
::
vector
<
ValueType
::
ptr
>
element_types
()
const
{
return
m_element_types
;
}
std
::
vector
<
ValueType
::
ptr
>
element_types
()
{
return
m_element_types
;
}
const
std
::
vector
<
ValueType
::
ptr
>
get_
element_types
()
const
{
return
m_element_types
;
}
std
::
vector
<
ValueType
::
ptr
>
set_
element_types
()
{
return
m_element_types
;
}
virtual
bool
operator
==
(
const
ValueType
::
ptr
&
that
)
const
override
;
...
...
@@ -121,13 +121,13 @@ namespace ngraph
** Set the type
** /param type The new type
**/
void
value_type
(
const
ValueType
::
ptr
&
value_type
)
{
m_value_type
=
value_type
;
}
void
set_
value_type
(
const
ValueType
::
ptr
&
value_type
)
{
m_value_type
=
value_type
;
}
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void
value_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
void
set_
value_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
{
m_value_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
}
...
...
@@ -135,11 +135,11 @@ namespace ngraph
/**
** The type associated with this value.
**/
ValueType
::
ptr
value_type
()
{
return
m_value_type
;
}
ValueType
::
ptr
get_
value_type
()
{
return
m_value_type
;
}
/**
** The type associated with this value.
**/
const
ValueType
::
ptr
value_type
()
const
{
return
m_value_type
;
}
const
ValueType
::
ptr
get_
value_type
()
const
{
return
m_value_type
;
}
protected
:
ValueType
::
ptr
m_value_type
;
};
...
...
src/ngraph/visualize.cpp
View file @
7ed0fe7d
...
...
@@ -33,9 +33,9 @@ void Visualize::add(node_ptr p)
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
{
for
(
auto
arg
:
node
->
arguments
())
for
(
auto
arg
:
node
->
get_
arguments
())
{
m_ss
<<
" "
<<
arg
->
node_id
()
<<
" -> "
<<
node
->
node_id
()
<<
";
\n
"
;
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_
node_id
()
<<
";
\n
"
;
}
});
}
...
...
src/ops/broadcast.cpp
View file @
7ed0fe7d
...
...
@@ -32,7 +32,7 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
void
BroadcastOp
::
propagate_types
()
{
auto
arg_type
=
m_arguments
.
at
(
0
)
->
value_type
();
auto
arg_type
=
m_arguments
.
at
(
0
)
->
get_
value_type
();
if
(
nullptr
==
arg_type
)
{
throw
ngraph_error
(
"Argument to broadcast is missing type."
);
...
...
@@ -47,11 +47,11 @@ void BroadcastOp::propagate_types()
{
target_shape
.
erase
(
target_shape
.
begin
()
+
*
i
);
}
if
(
Shape
{
target_shape
}
!=
arg_tensor_view_type
->
shape
())
if
(
Shape
{
target_shape
}
!=
arg_tensor_view_type
->
get_
shape
())
{
throw
ngraph_error
(
"Broadcast arg, shape, and axes are incompatible"
);
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_value_type
=
make_shared
<
TensorViewType
>
(
arg_tensor_view_type
->
element_type
(),
m_shape
);
m_value_type
=
make_shared
<
TensorViewType
>
(
arg_tensor_view_type
->
get_
element_type
(),
m_shape
);
}
src/ops/dot.cpp
View file @
7ed0fe7d
...
...
@@ -27,21 +27,21 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
void
DotOp
::
propagate_types
()
{
auto
arg0_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
value_type
());
auto
arg1_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
1
)
->
value_type
());
auto
arg0_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
0
)
->
get_
value_type
());
auto
arg1_tensor_type
=
dynamic_pointer_cast
<
TensorViewType
>
(
m_arguments
.
at
(
1
)
->
get_
value_type
());
if
(
nullptr
==
arg0_tensor_type
||
nullptr
==
arg1_tensor_type
)
{
throw
ngraph_error
(
"Arguments to dot must be tensor views"
);
}
if
(
arg0_tensor_type
->
element_type
()
!=
arg1_tensor_type
->
element_type
())
if
(
arg0_tensor_type
->
get_element_type
()
!=
arg1_tensor_type
->
get_
element_type
())
{
throw
ngraph_error
(
"Arguments to dot must have the same element type"
);
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector
<
size_t
>
arg0_shape
=
arg0_tensor_type
->
shape
();
vector
<
size_t
>
arg1_shape
=
arg1_tensor_type
->
shape
();
vector
<
size_t
>
arg0_shape
=
arg0_tensor_type
->
get_
shape
();
vector
<
size_t
>
arg1_shape
=
arg1_tensor_type
->
get_
shape
();
size_t
arg0_reduction
=
arg0_shape
.
size
()
-
1
;
size_t
arg1_reduction
;
if
(
arg1_shape
.
size
()
>
1
)
...
...
@@ -60,5 +60,5 @@ void DotOp::propagate_types()
copy
(
arg0_shape
.
begin
(),
arg0_shape
.
begin
()
+
arg1_reduction
,
result_shape
.
end
());
copy
(
arg1_shape
.
begin
(),
arg1_shape
.
begin
()
+
arg1_reduction
,
result_shape
.
end
());
copy
(
arg1_shape
.
begin
()
+
arg1_reduction
,
arg1_shape
.
end
(),
result_shape
.
end
());
m_value_type
=
make_shared
<
TensorViewType
>
(
arg0_tensor_type
->
element_type
(),
result_shape
);
m_value_type
=
make_shared
<
TensorViewType
>
(
arg0_tensor_type
->
get_
element_type
(),
result_shape
);
}
src/ops/op.cpp
View file @
7ed0fe7d
...
...
@@ -20,10 +20,10 @@
using
namespace
ngraph
;
using
namespace
std
;
std
::
string
ngraph
::
Op
::
node_id
()
const
std
::
string
ngraph
::
Op
::
get_
node_id
()
const
{
stringstream
ss
;
ss
<<
op_class_name
()
<<
"_"
<<
m_instance_id
;
ss
<<
get_
op_class_name
()
<<
"_"
<<
m_instance_id
;
return
ss
.
str
();
}
...
...
src/ops/parameter.cpp
View file @
7ed0fe7d
...
...
@@ -56,7 +56,7 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_
return
make_shared
<
Parameter
>
(
make_shared
<
TensorViewType
>
(
element_type
,
shape
));
}
std
::
string
ngraph
::
Parameter
::
node_id
()
const
std
::
string
ngraph
::
Parameter
::
get_
node_id
()
const
{
stringstream
ss
;
ss
<<
"parameter_"
<<
m_instance_id
;
...
...
src/types/type.cpp
View file @
7ed0fe7d
...
...
@@ -26,11 +26,11 @@ bool TensorViewType::operator==(const ValueType::ptr& that) const
{
return
false
;
}
if
(
that_tvt
->
element_type
()
!=
m_element_type
)
if
(
that_tvt
->
get_
element_type
()
!=
m_element_type
)
{
return
false
;
}
if
(
that_tvt
->
shape
()
!=
m_shape
)
if
(
that_tvt
->
get_
shape
()
!=
m_shape
)
{
return
false
;
}
...
...
@@ -44,5 +44,5 @@ bool TupleType::operator==(const ValueType::ptr& that) const
{
return
false
;
}
return
that_tvt
->
element_types
()
==
element_types
();
return
that_tvt
->
get_element_types
()
==
get_
element_types
();
}
src/util.cpp
View file @
7ed0fe7d
...
...
@@ -136,11 +136,11 @@ static void traverse_nodes(std::shared_ptr<ngraph::Node> p,
std
::
set
<
size_t
>&
instances_seen
)
{
f
(
p
);
for
(
auto
arg
:
p
->
arguments
())
for
(
auto
arg
:
p
->
get_
arguments
())
{
if
(
instances_seen
.
find
(
arg
->
instance_id
())
==
instances_seen
.
end
())
if
(
instances_seen
.
find
(
arg
->
get_
instance_id
())
==
instances_seen
.
end
())
{
instances_seen
.
insert
(
arg
->
instance_id
());
instances_seen
.
insert
(
arg
->
get_
instance_id
());
traverse_nodes
(
arg
,
f
,
instances_seen
);
}
}
...
...
test/build_graph.cpp
View file @
7ed0fe7d
...
...
@@ -30,8 +30,8 @@ TEST(build_graph, build_simple)
auto
broadcast_1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
b1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
dot
=
node
<
DotOp
>
(
arg2
,
arg0
);
ASSERT_EQ
(
dot
->
arguments
()[
0
],
arg2
);
ASSERT_EQ
(
dot
->
arguments
()[
1
],
arg0
);
ASSERT_EQ
(
dot
->
get_
arguments
()[
0
],
arg2
);
ASSERT_EQ
(
dot
->
get_
arguments
()[
1
],
arg0
);
auto
cluster_0
=
op
::
function
(
dot
,
{
arg0
,
arg1
,
arg2
,
arg3
});
...
...
@@ -80,22 +80,22 @@ TEST(build_graph, literal)
//auto float0 = FloatScalarConstant::make(3.0);
auto
float0
=
node
<
FloatScalarConstant
>
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
(),
Shape
{});
ASSERT_EQ
(
float0
->
value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
value_type
(),
float_scalar_type
);
ASSERT_EQ
(
float0
->
get_
value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
get_
value_type
(),
float_scalar_type
);
auto
d
=
node
<
DotOp
>
(
float0
,
float0
);
ASSERT_EQ
(
d
->
arguments
().
at
(
0
),
float0
);
ASSERT_EQ
(
d
->
arguments
().
at
(
1
),
float0
);
ASSERT_EQ
(
d
->
get_
arguments
().
at
(
0
),
float0
);
ASSERT_EQ
(
d
->
get_
arguments
().
at
(
1
),
float0
);
// float scalar from an int
auto
float1
=
node
<
FloatScalarConstant
>
(
3
);
ASSERT_EQ
(
float1
->
value
(),
3
);
ASSERT_EQ
(
*
float1
->
value_type
(),
float_scalar_type
);
ASSERT_EQ
(
float1
->
get_
value
(),
3
);
ASSERT_EQ
(
*
float1
->
get_
value_type
(),
float_scalar_type
);
auto
int32_0
=
node
<
Int32ScalarConstant
>
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
(),
Shape
{});
ASSERT_EQ
(
int32_0
->
value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
value_type
(),
int32_scalar_type
);
ASSERT_NE
(
*
int32_0
->
value_type
(),
float_scalar_type
);
ASSERT_EQ
(
int32_0
->
get_
value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
get_
value_type
(),
int32_scalar_type
);
ASSERT_NE
(
*
int32_0
->
get_
value_type
(),
float_scalar_type
);
}
// Check argument inverses
...
...
test/topological_sort.cpp
View file @
7ed0fe7d
...
...
@@ -42,7 +42,7 @@ TEST(top_sort, basic)
auto
f0
=
op
::
function
(
r0
,
{
arg0
,
arg1
});
ASSERT_NE
(
nullptr
,
f0
);
ASSERT_EQ
(
2
,
r0
->
arguments
().
size
());
ASSERT_EQ
(
2
,
r0
->
get_
arguments
().
size
());
auto
op_r0
=
static_pointer_cast
<
Op
>
(
r0
);
cout
<<
"op_r0 name "
<<
*
r0
<<
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment