diff --git a/src/map_memory.cpp b/src/map_memory.cpp index d3179b3653cc5..f09b3e29d2223 100644 --- a/src/map_memory.cpp +++ b/src/map_memory.cpp @@ -2,62 +2,80 @@ #include -void map_memory::trim( int limit ) +template +T lru_cache::get( const tripoint &pos, const T &default_ ) const { - while( tiles.size() > static_cast( limit ) ) { - tile_map.erase( tiles.back().first ); - tiles.pop_back(); + auto found = map.find( pos ); + if( found != map.end() ) { + return found->second->second; } - while( symbols.size() > static_cast( limit ) ) { - symbol_map.erase( symbols.back().first ); - symbols.pop_back(); + return default_; +} + +template +void lru_cache::insert( int limit, const tripoint &pos, const T &t ) +{ + auto found = map.find( pos ); + + if( found == map.end() ) { + // Need new entry in map. Make the new list entry and point to it. + ordered_list.emplace_back( pos, t ); + map[pos] = std::prev( ordered_list.end() ); + trim( limit ); + } else { + // Splice existing entry to the back. Does not invalidate the + // iterator, so no need to change the map. + auto list_iterator = found->second; + ordered_list.splice( ordered_list.end(), ordered_list, list_iterator ); + // Update the moved item + list_iterator->second = t; } } -memorized_terrain_tile map_memory::get_tile( const tripoint &pos ) const +template +void lru_cache::trim( int limit ) { - auto found_tile = tile_map.find( pos ); - if( found_tile != tile_map.end() ) { - return found_tile->second->second; + while( ordered_list.size() > static_cast( limit ) ) { + map.erase( ordered_list.front().first ); + ordered_list.pop_front(); } - return { "", 0, 0 }; +} + +template +void lru_cache::clear() +{ + map.clear(); + ordered_list.clear(); +} + +template +const std::list::Pair> &lru_cache::list() const +{ + return ordered_list; +} + +template class lru_cache; +template class lru_cache; + +static const memorized_terrain_tile default_tile{ "", 0, 0 }; + +memorized_terrain_tile map_memory::get_tile( const tripoint &pos ) const +{ + return tile_cache.get( pos, default_tile ); } void map_memory::memorize_tile( int limit, const tripoint &pos, const std::string &ter, const int subtile, const int rotation ) { - memorized_terrain_tile new_tile{ ter, subtile, rotation }; - tiles.push_front( std::make_pair( pos, new_tile ) ); - auto found_tile = tile_map.find( pos ); - if( found_tile != tile_map.end() ) { - // Remove redundant entry since we pushed one to the front. - tiles.erase( found_tile->second ); - found_tile->second = tiles.begin(); - } else { - tile_map[pos] = tiles.begin(); - trim( limit ); - } + tile_cache.insert( limit, pos, memorized_terrain_tile{ ter, subtile, rotation } ); } long map_memory::get_symbol( const tripoint &pos ) const { - auto found_tile = symbol_map.find( pos ); - if( found_tile != symbol_map.end() ) { - return found_tile->second->second; - } - return 0; + return symbol_cache.get( pos, 0 ); } void map_memory::memorize_symbol( int limit, const tripoint &pos, const long symbol ) { - symbols.emplace_front( pos, symbol ); - auto found_tile = symbol_map.find( pos ); - if( found_tile != symbol_map.end() ) { - // Remove redundant entry since we pushed on to the front. - symbols.erase( found_tile->second ); - found_tile->second = symbols.begin(); - } else { - symbol_map[pos] = symbols.begin(); - trim( limit ); - } + symbol_cache.insert( limit, pos, symbol ); } diff --git a/src/map_memory.h b/src/map_memory.h index 3dac6131cdbd4..beb26644b8d15 100644 --- a/src/map_memory.h +++ b/src/map_memory.h @@ -17,6 +17,26 @@ struct memorized_terrain_tile { int rotation; }; +template +class lru_cache +{ + public: + using Pair = std::pair; + + void insert( int limit, const tripoint &, const T & ); + T get( const tripoint &, const T &default_ ) const; + + void clear(); + const std::list &list() const; + private: + void trim( int limit ); + std::list ordered_list; + std::unordered_map::iterator> map; +}; + +extern template class lru_cache; +extern template class lru_cache; + class map_memory { public: @@ -33,12 +53,8 @@ class map_memory long get_symbol( const tripoint &p ) const; private: void trim( int limit ); - using tile_pair = std::pair; - std::list tiles; - std::unordered_map::iterator> tile_map; - using symbol_pair = std::pair; - std::list symbols; - std::unordered_map::iterator> symbol_map; + lru_cache tile_cache; + lru_cache symbol_cache; }; #endif diff --git a/src/savegame_json.cpp b/src/savegame_json.cpp index 5867ad4d524f1..724388bd4af24 100644 --- a/src/savegame_json.cpp +++ b/src/savegame_json.cpp @@ -2613,7 +2613,7 @@ void map_memory::store( JsonOut &jsout ) const { jsout.member( "map_memory_tiles" ); jsout.start_array(); - for( const auto &elem : tiles ) { + for( const auto &elem : tile_cache.list() ) { jsout.start_object(); jsout.member( "x", elem.first.x ); jsout.member( "y", elem.first.y ); @@ -2627,7 +2627,7 @@ void map_memory::store( JsonOut &jsout ) const jsout.member( "map_memory_curses" ); jsout.start_array(); - for( const auto &elem : symbols ) { + for( const auto &elem : symbol_cache.list() ) { jsout.start_object(); jsout.member( "x", elem.first.x ); jsout.member( "y", elem.first.y ); @@ -2641,8 +2641,7 @@ void map_memory::store( JsonOut &jsout ) const void map_memory::load( JsonObject &jsin ) { JsonArray map_memory_tiles = jsin.get_array( "map_memory_tiles" ); - tiles.clear(); - tile_map.clear(); + tile_cache.clear(); while( map_memory_tiles.has_more() ) { JsonObject pmap = map_memory_tiles.next_object(); const tripoint p( pmap.get_int( "x" ), pmap.get_int( "y" ), pmap.get_int( "z" ) ); @@ -2651,8 +2650,7 @@ void map_memory::load( JsonObject &jsin ) } JsonArray map_memory_curses = jsin.get_array( "map_memory_curses" ); - symbols.clear(); - symbol_map.clear(); + symbol_cache.clear(); while( map_memory_curses.has_more() ) { JsonObject pmap = map_memory_curses.next_object(); const tripoint p( pmap.get_int( "x" ), pmap.get_int( "y" ), pmap.get_int( "z" ) ); diff --git a/tests/map_memory.cpp b/tests/map_memory.cpp new file mode 100644 index 0000000000000..77856802c3798 --- /dev/null +++ b/tests/map_memory.cpp @@ -0,0 +1,85 @@ +#include "catch/catch.hpp" + +#include "map_memory.h" +#include "json.h" + +static const tripoint p1{ 0, 0, 1 }; +static const tripoint p2{ 0, 0, 2 }; +static const tripoint p3{ 0, 0, 3 }; + +TEST_CASE( "map_memory_defaults", "[map_memory]" ) +{ + map_memory memory; + CHECK( memory.get_symbol( p1 ) == 0 ); + memorized_terrain_tile default_tile = memory.get_tile( p1 ); + CHECK( default_tile.tile == "" ); + CHECK( default_tile.subtile == 0 ); + CHECK( default_tile.rotation == 0 ); +} + +TEST_CASE( "map_memory_remembers", "[map_memory]" ) +{ + map_memory memory; + memory.memorize_symbol( 2, p1, 1 ); + memory.memorize_symbol( 2, p2, 2 ); + CHECK( memory.get_symbol( p1 ) == 1 ); + CHECK( memory.get_symbol( p2 ) == 2 ); +} + +TEST_CASE( "map_memory_limited", "[map_memory]" ) +{ + map_memory memory; + memory.memorize_symbol( 2, p1, 1 ); + memory.memorize_symbol( 2, p2, 2 ); + memory.memorize_symbol( 2, p3, 3 ); + CHECK( memory.get_symbol( p1 ) == 0 ); +} + +TEST_CASE( "map_memory_overwrites", "[map_memory]" ) +{ + map_memory memory; + memory.memorize_symbol( 2, p1, 1 ); + memory.memorize_symbol( 2, p2, 2 ); + memory.memorize_symbol( 2, p2, 3 ); + CHECK( memory.get_symbol( p1 ) == 1 ); + CHECK( memory.get_symbol( p2 ) == 3 ); +} + +TEST_CASE( "map_memory_erases_lru", "[map_memory]" ) +{ + map_memory memory; + memory.memorize_symbol( 2, p1, 1 ); + memory.memorize_symbol( 2, p2, 2 ); + memory.memorize_symbol( 2, p1, 1 ); + memory.memorize_symbol( 2, p3, 3 ); + CHECK( memory.get_symbol( p1 ) == 1 ); + CHECK( memory.get_symbol( p2 ) == 0 ); + CHECK( memory.get_symbol( p3 ) == 3 ); +} + +TEST_CASE( "map_memory_survives_save_lod", "[map_memory]" ) +{ + map_memory memory; + memory.memorize_symbol( 2, p1, 1 ); + memory.memorize_symbol( 2, p2, 2 ); + + // Save and reload + std::ostringstream jsout_s; + JsonOut jsout( jsout_s ); + jsout.start_object( "m" ); + memory.store( jsout ); + jsout.end_object(); + + INFO( "Json was: " << jsout_s.str() ); + std::istringstream jsin_s( jsout_s.str() ); + JsonIn jsin( jsin_s ); + map_memory memory2; + JsonObject json = jsin.get_object(); + memory2.load( json ); + + memory.memorize_symbol( 2, p3, 3 ); + memory2.memorize_symbol( 2, p3, 3 ); + CHECK( memory.get_symbol( p1 ) == memory2.get_symbol( p1 ) ); + CHECK( memory.get_symbol( p2 ) == memory2.get_symbol( p2 ) ); + CHECK( memory.get_symbol( p3 ) == memory2.get_symbol( p3 ) ); +}