Commit 44b58722 authored by Nishant Patel's avatar Nishant Patel Committed by Sang Ik Lee

LRU cache for dynamic shapes (#3827)

* LRU cache for caching graphs

* static

* LRU Cache

* Cache graph

* Make data members static

* Temp test case functional

* Temp test case functional

* Fix shape

* Make lru local to dynamic wrapper

* Make cache thread safe

* Remove static copies of data members

* Add a separator  between shapes of different inputs

* Clear list and map in destructor

* Caching on values of shape relevant inputs

* Replace cout's by NGRAPH_INFO

* Add a environment variable for cache size

* Add mutex header

* style

* change to int64_t

* Save the cloned function to get the output shape to allocate output storage

* Pass inputs without wrapping

* Fix conv shape relevant inputs

* gcc 4.8 doesnt support ostring stream as a copyable object

* Pass key by reference

* PR feedback

* Apply suggestions from code review

* Replace malloc
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 4614e4d9
......@@ -590,6 +590,8 @@ set (SRC
runtime/backend.hpp
runtime/backend_manager.cpp
runtime/backend_manager.hpp
runtime/cache.cpp
runtime/cache.hpp
runtime/chrome_trace.cpp
runtime/chrome_trace.hpp
runtime/executable.cpp
......
......@@ -519,8 +519,6 @@ void op::v1::ConvolutionBackpropFilters::validate_and_infer_types()
").");
}
set_input_is_relevant_to_shape(0);
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
set_output_type(0, forward_result_et, filters_shape);
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cache.hpp"
using namespace ngraph;
using namespace std;
// Constructor
runtime::LRUCache::LRUCache()
{
char* cache_size = getenv("NGRAPH_CACHE_SIZE");
if (cache_size == nullptr)
{
m_cache_size = 1024; // TODO(nbpatel): Figure out a default size for the cache
}
else
{
m_cache_size = atoi(cache_size);
}
m_map = {};
m_list = {};
}
// Destructor
runtime::LRUCache::~LRUCache()
{
m_list.clear();
m_map.clear();
m_clone_function_map.clear();
}
void runtime::LRUCache::convert_shape_to_string(const vector<int>& shape, ostringstream& key)
{
if (!shape.empty())
{
std::copy(shape.begin(), shape.end(), std::ostream_iterator<int>(key, ", "));
}
}
void runtime::LRUCache::add_entry(const vector<int>& shape,
shared_ptr<runtime::Executable> exec,
shared_ptr<Function> func)
{
std::lock_guard<std::mutex> guard(m_mutex);
ostringstream key;
// check if the list is empty
if (m_list.size() == m_cache_size)
{
ostringstream key;
convert_shape_to_string(m_list.back(), key);
m_list.pop_back();
m_map.erase(key.str());
}
convert_shape_to_string(shape, key);
m_map.insert({key.str(), exec});
m_list.push_front(shape);
m_clone_function_map.insert({key.str(), func});
}
bool runtime::LRUCache::is_cached(const vector<int>& shape)
{
for (auto itr = m_list.begin(); itr != m_list.end(); itr++)
{
if (*itr == shape)
{
return true;
}
}
return false;
}
shared_ptr<runtime::Executable> runtime::LRUCache::get_cached_entry(const vector<int>& shape)
{
std::lock_guard<std::mutex> guard(m_mutex);
ostringstream key;
convert_shape_to_string(shape, key);
// find the entry and return the function
auto it = m_map.find(key.str());
if (it == m_map.end())
{
throw ngraph_error("Entry not found in cache");
}
else
{
// update list to push this reference to the front
for (auto itr = m_list.begin(); itr != m_list.end(); itr++)
{
if (*itr == shape)
{
m_list.remove(shape);
m_list.push_front(shape);
break;
}
}
return it->second;
}
}
// Need the clone function to get the output shape so that
// storage can be allocated for output
shared_ptr<Function> runtime::LRUCache::get_cloned_function(const vector<int>& shape)
{
std::lock_guard<std::mutex> guard(m_mutex);
ostringstream key;
convert_shape_to_string(shape, key);
// find the entry and return the function
auto it = m_clone_function_map.find(key.str());
if (it == m_clone_function_map.end())
{
throw ngraph_error("Cloned function not found");
}
return it->second;
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <list>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/runtime/executable.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
class LRUCache : public std::enable_shared_from_this<LRUCache>
{
public:
using GraphCache = std::unordered_map<std::string, std::shared_ptr<Executable>>;
using ClonedFunctionMap = std::unordered_map<std::string, std::shared_ptr<Function>>;
LRUCache();
virtual ~LRUCache();
void add_entry(const std::vector<int>& shape,
std::shared_ptr<Executable> exec,
std::shared_ptr<Function> func);
bool is_cached(const std::vector<int>& shape);
std::shared_ptr<Executable> get_cached_entry(const std::vector<int>& shape);
void convert_shape_to_string(const std::vector<int>& shape, std::ostringstream& key);
std::shared_ptr<Function> get_cloned_function(const std::vector<int>& shape);
private:
int m_cache_size;
GraphCache m_map;
ClonedFunctionMap m_clone_function_map;
std::list<std::vector<int>> m_list;
std::mutex m_mutex;
};
}
}
......@@ -22,6 +22,7 @@
#include <vector>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/cache.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/tensor.hpp"
......@@ -100,6 +101,8 @@ public:
private:
std::shared_ptr<ngraph::Function> m_wrapped_function;
std::shared_ptr<ngraph::runtime::Backend> m_wrapped_backend;
std::shared_ptr<ngraph::runtime::LRUCache> m_lru =
std::make_shared<ngraph::runtime::LRUCache>();
bool m_enable_performance_collection;
};
......
......@@ -1408,6 +1408,8 @@ NGRAPH_TEST(${BACKEND_NAME}, avg_pool_bprop_2d_2channel_2image_dyn_shape)
float denom = 2 * 2;
ex->call_with_validate({t_r}, {deltas, forward_shape});
ex->call_with_validate({t_r}, {deltas, forward_shape});
ex->call_with_validate({t_r}, {deltas, forward_shape});
ASSERT_EQ(t_r->get_shape(), (Shape{2, 2, 3, 3}));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment