Unverified Commit 484c0a0d authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

add method to Function to get the size of a graph (#2367)

parent 122754c1
......@@ -191,3 +191,26 @@ void Function::replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> rep
{
ngraph::replace_node(old, repl);
}
size_t Function::get_graph_size() const
{
size_t total_size = 0;
for (auto node : get_ops())
{
total_size += sizeof(*node);
if (node->description() == "Constant")
{
const Shape& shape = node->get_outputs()[0].get_shape();
size_t const_size = node->get_outputs()[0].get_element_type().size();
if (shape.size() == 0)
{
total_size += const_size;
}
else
{
total_size += (const_size * shape_size(node->get_outputs()[0].get_shape()));
}
}
}
return total_size;
}
......@@ -85,6 +85,11 @@ namespace ngraph
void validate_nodes_and_infer_types();
/// \brief Returns the sum of the size of all nodes in the graph plus the size of
/// all constant data. This has little value beyond comparing the relative size of
/// graphs and should not be considered the actual memory consumption of a graph.
size_t get_graph_size() const;
protected:
ResultVector m_results;
ParameterVector m_parameters;
......
......@@ -32,6 +32,7 @@ set(SRC
control_dependencies.cpp
coordinate.cpp
copy.cpp
core.cpp
cpio.cpp
cse.cpp
element_type.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
using namespace ngraph;
using namespace std;
TEST(core, function_size)
{
const string m1 = file_util::path_join(SERIALIZED_ZOO, "mxnet/mnist_mlp_forward.json");
const string m2 = file_util::path_join(SERIALIZED_ZOO, "mxnet/10_bucket_LSTM.json");
auto f1 = deserialize(m1);
auto f2 = deserialize(m2);
auto s1 = f1->get_graph_size();
auto s2 = f2->get_graph_size();
EXPECT_GT(s2, s1);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment