Skip to content

Commit

Permalink
[setting] Add 'mixerSeed' setting to enable splitting seed space betwe…
Browse files Browse the repository at this point in the history
…en training and test runs.

Added a new setting for DMLab so that seeds for training and test runs are separate. This happens internally (using full 64-bit seeds instead of the 32-bit ones exposed by the environment API). This prevents any unnecessary folding due to training.
  • Loading branch information
DeepMind Lab Team authored and tkoeppe committed May 1, 2018
1 parent 9154942 commit d817a04
Show file tree
Hide file tree
Showing 16 changed files with 131 additions and 38 deletions.
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Added further pre-built maps, which removes the need for the expensive
:map_assets build step.
2. Allow game to be rendered with top-left as origin instead of bottom-left.
3. Add 'mixerSeed' setting to change behaviour of all random number generators.

## release-2018-02-07 February 2018 release

Expand Down
19 changes: 15 additions & 4 deletions deepmind/engine/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ static bool get_native_app(void* userdata) {
return static_cast<Context*>(userdata)->NativeApp();
}

// Hook trampoline: recovers the Context from the opaque userdata pointer and
// hands it the mixer seed, converted to an unsigned 32-bit value.
static void set_mixer_seed(void* userdata, int v) {
  auto* context = static_cast<Context*>(userdata);
  const auto mixer_seed = static_cast<std::uint32_t>(v);
  context->SetMixerSeed(mixer_seed);
}

static void set_actions(void* userdata, double look_down_up,
double look_left_right, signed char move_back_forward,
signed char strafe_left_right, signed char crouch_jump,
Expand Down Expand Up @@ -480,7 +485,7 @@ lua::NResultsOr MapMakerModule(lua_State* L) {
LuaTextLevelMaker::CreateObject(
L, ctx->ExecutableRunfiles(), ctx->TempDirectory(),
ctx->UseLocalLevelCache(), ctx->UseGlobalLevelCache(),
ctx->LevelCacheParams());
ctx->LevelCacheParams(), ctx->MixerSeed());
return 1;
} else {
return "Missing context!";
Expand All @@ -507,6 +512,7 @@ Context::Context(lua::Vm lua_vm, const char* executable_runfiles,
: lua_vm_(std::move(lua_vm)),
native_app_(false),
actions_{},
mixer_seed_(0),
level_cache_params_{},
game_(executable_runfiles, calls, file_reader_override,
temp_folder != nullptr ? temp_folder : ""),
Expand All @@ -528,6 +534,7 @@ Context::Context(lua::Vm lua_vm, const char* executable_runfiles,
hooks->run_lua_snippet = run_lua_snippet;
hooks->set_native_app = set_native_app;
hooks->get_native_app = get_native_app;
hooks->set_mixer_seed = set_mixer_seed;
hooks->set_actions = set_actions;
hooks->get_actions = get_actions;
hooks->find_model = find_model;
Expand Down Expand Up @@ -652,7 +659,8 @@ int Context::Init() {
lua_vm_.AddCModuleToSearchers(
"dmlab.system.tensor", tensor::LuaTensorConstructors);
lua_vm_.AddCModuleToSearchers(
"dmlab.system.maze_generation", LuaMazeGeneration::Require);
"dmlab.system.maze_generation", &lua::Bind<LuaMazeGeneration::Require>,
{reinterpret_cast<void*>(static_cast<std::uintptr_t>(mixer_seed_))});
lua_vm_.AddCModuleToSearchers(
"dmlab.system.map_maker", &lua::Bind<MapMakerModule>, {this});
lua_vm_.AddCModuleToSearchers(
Expand All @@ -668,7 +676,9 @@ int Context::Init() {
&lua::Bind<ContextPickups::Module>,
{MutablePickups()});
lua_vm_.AddCModuleToSearchers(
"dmlab.system.random", &lua::Bind<LuaRandom::Require>, {UserPrbg()});
"dmlab.system.random", &lua::Bind<LuaRandom::Require>,
{UserPrbg(),
reinterpret_cast<void*>(static_cast<std::uintptr_t>(mixer_seed_))});
lua_vm_.AddCModuleToSearchers(
"dmlab.system.model", &lua::Bind<ModelModule>,
{const_cast<DeepmindCalls*>(Game().Calls())});
Expand Down Expand Up @@ -704,7 +714,8 @@ int Context::Init() {
}

int Context::Start(int episode, int seed) {
EnginePrbg()->seed(seed);
EnginePrbg()->seed(static_cast<std::uint64_t>(seed) ^
(static_cast<std::uint64_t>(mixer_seed_) << 32));
MutableGame()->NextMap();
lua_State* L = lua_vm_.get();
script_table_ref_.PushMemberFunction("start");
Expand Down
12 changes: 12 additions & 0 deletions deepmind/engine/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,15 @@ class Context {
// generate new positive integers.
int MakeRandomSeed();

// Specifies a mixer value to be combined with all the seeds passed to this
// environment, before using them with the internal PRBGs. This is done in
// a way which guarantees that the resulting seeds span disjoint subsets of
// the integers in [0, 2^64) for each different mixer value. However, the
// sequences produced by the environment's PRBGs are not necessarily disjoint.
// This call only stores the value; it does not reseed any generator itself.
void SetMixerSeed(std::uint32_t s) { mixer_seed_ = s; }

// Returns the currently stored mixer value.
std::uint32_t MixerSeed() const { return mixer_seed_; }

std::mt19937_64* UserPrbg() { return &user_prbg_; }

std::mt19937_64* EnginePrbg() { return &engine_prbg_; }
Expand Down Expand Up @@ -480,6 +489,9 @@ class Context {
// A pseudo-random-bit generator for exclusive use by users.
std::mt19937_64 user_prbg_;

// Stores the mixer seed for the PRBG.
std::uint32_t mixer_seed_;

// A pseudo-random-bit generator for exclusive use of the engine. Seeded each
// episode with the episode start seed.
std::mt19937_64 engine_prbg_;
Expand Down
21 changes: 16 additions & 5 deletions deepmind/engine/lua_maze_generation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ namespace lab {
namespace {

std::mt19937_64* GetRandomNumberGenerator(lua::TableRef* table,
std::mt19937_64* seeded_rng) {
std::mt19937_64* seeded_rng,
std::uint64_t mixer_seq) {
std::mt19937_64* prng = nullptr;
lua_State* L = table->LuaState();
table->LookUpToStack("random");
Expand All @@ -51,7 +52,7 @@ std::mt19937_64* GetRandomNumberGenerator(lua::TableRef* table,
if (prng == nullptr) {
int seed = 0;
if (table->LookUp("seed", &seed)) {
seeded_rng->seed(seed);
seeded_rng->seed(static_cast<std::uint64_t>(seed) ^ mixer_seq);
prng = seeded_rng;
}
}
Expand Down Expand Up @@ -112,11 +113,19 @@ class LuaRoom : public lua::Class<LuaRoom> {
std::vector<maze_generation::Pos> room_;
};

// Bit toggle sequence applied to the 32 MSB of the 64-bit seeds fed to the maze
// generation PRBGs, with the intention of creating disjoint seed subspaces for
// each different mixer_seed value, as described in python_api.md. Assigned in
// Require() from the mixer seed upvalue, shifted into the upper 32 bits.
std::uint64_t LuaMazeGeneration::mixer_seq_ = 0;

const char* LuaMazeGeneration::ClassName() {
return "deepmind.lab.LuaMazeGeneration";
}

int LuaMazeGeneration::Require(lua_State* L) {
lua::NResultsOr LuaMazeGeneration::Require(lua_State* L) {
std::uintptr_t mixer_seed =
reinterpret_cast<std::uintptr_t>(lua_touserdata(L, lua_upvalueindex(1)));
mixer_seq_ = static_cast<std::uint64_t>(mixer_seed) << 32;
auto table = lua::TableRef::Create(L);
table.Insert("mazeGeneration", &lua::Bind<LuaMazeGeneration::Create>);
table.Insert("randomMazeGeneration",
Expand Down Expand Up @@ -162,7 +171,8 @@ lua::NResultsOr LuaMazeGeneration::CreateRandom(lua_State* L) {
lua::Read(L, -1, &table);

std::mt19937_64 seeded_rng;
std::mt19937_64* prng = GetRandomNumberGenerator(&table, &seeded_rng);
std::mt19937_64* prng =
GetRandomNumberGenerator(&table, &seeded_rng, mixer_seq_);
if (prng == nullptr) {
return "[randomMazeGeneration] - Must construct with 'random' a random "
"number generator. ('seed' is deprecated.)";
Expand Down Expand Up @@ -544,7 +554,8 @@ lua::NResultsOr LuaMazeGeneration::VisitRandomPath(lua_State* L) {
return "[visitRandomPath] - must supply table";
}
std::mt19937_64 seeded_rng;
std::mt19937_64* prng = GetRandomNumberGenerator(&table, &seeded_rng);
std::mt19937_64* prng =
GetRandomNumberGenerator(&table, &seeded_rng, mixer_seq_);
if (prng == nullptr) {
return "[visitRandomPath] - must supply 'random' with random number "
"generator. ('seed' is deprecated.)";
Expand Down
6 changes: 4 additions & 2 deletions deepmind/engine/lua_maze_generation.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2016 Google Inc.
// Copyright (C) 2016-2018 Google Inc.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -47,7 +47,7 @@ class LuaMazeGeneration : public lua::Class<LuaMazeGeneration> {

// Returns table of constructors and standalone functions.
// [0, 1, -]
static int Require(lua_State* L);
static lua::NResultsOr Require(lua_State* L);

private:
// Constructs a LuaMazeGeneration.
Expand Down Expand Up @@ -197,6 +197,8 @@ class LuaMazeGeneration : public lua::Class<LuaMazeGeneration> {
lua::NResultsOr CountVariations(lua_State* L);

maze_generation::TextMaze text_maze_;

static std::uint64_t mixer_seq_;
};

} // namespace lab
Expand Down
10 changes: 6 additions & 4 deletions deepmind/engine/lua_maze_generation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ class LuaMazeGenerationTest : public lua::testing::TestWithVm {
protected:
LuaMazeGenerationTest() {
LuaMazeGeneration::Register(L);
vm()->AddCModuleToSearchers("dmlab.system.maze_generation",
LuaMazeGeneration::Require);
vm()->AddCModuleToSearchers(
"dmlab.system.maze_generation", &lua::Bind<LuaMazeGeneration::Require>,
{reinterpret_cast<void*>(static_cast<std::uintptr_t>(0))});
LuaRandom::Register(L);
vm()->AddCModuleToSearchers("dmlab.system.sys_random",
&lua::Bind<LuaRandom::Require>, {&prbg_});
vm()->AddCModuleToSearchers(
"dmlab.system.sys_random", &lua::Bind<LuaRandom::Require>,
{&prbg_, reinterpret_cast<void*>(static_cast<std::uintptr_t>(0))});
}

std::mt19937_64 prbg_;
Expand Down
8 changes: 5 additions & 3 deletions deepmind/engine/lua_random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ bool ReadLargeNumber(lua_State* L, int idx, RbgNumType* num) {
lua::NResultsOr LuaRandom::Require(lua_State* L) {
if (auto* prbg = static_cast<std::mt19937_64*>(
lua_touserdata(L, lua_upvalueindex(1)))) {
LuaRandom::CreateObject(L, prbg);
std::uintptr_t mixer_seed = reinterpret_cast<std::uintptr_t>(
lua_touserdata(L, lua_upvalueindex(2)));
LuaRandom::CreateObject(L, prbg, mixer_seed);
return 1;
} else {
return "Missing std::mt19937_64 pointer in up value!";
Expand All @@ -87,7 +89,7 @@ lua::NResultsOr LuaRandom::Seed(lua_State* L) {
RbgNumType k;

if (ReadLargeNumber(L, -1, &k)) {
prbg_->seed(k);
prbg_->seed(k ^ mixer_seq_);
return 0;
} else if (lua::Read(L, -1, &s)) {
auto& err = errno; // cache TLS-lookup
Expand All @@ -96,7 +98,7 @@ lua::NResultsOr LuaRandom::Seed(lua_State* L) {
unsigned long long int n = std::strtoull(s.data(), &ep, 0);
if (ep != s.data() && *ep == '\0' && err == 0 &&
n <= std::numeric_limits<RbgNumType>::max()) {
prbg_->seed(n);
prbg_->seed(n ^ mixer_seq_);
return 0;
}
}
Expand Down
4 changes: 3 additions & 1 deletion deepmind/engine/lua_random.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class LuaRandom : public lua::Class<LuaRandom> {

public:
// Constructed with a non-owning view of a PRBG instance.
explicit LuaRandom(std::mt19937_64* prbg) : prbg_(prbg) {}
explicit LuaRandom(std::mt19937_64* prbg, std::uint32_t mixer_seed)
: prbg_(prbg), mixer_seq_(static_cast<std::uint64_t>(mixer_seed) << 32) {}

// Registers the class as well as member functions:
//
Expand Down Expand Up @@ -124,6 +125,7 @@ class LuaRandom : public lua::Class<LuaRandom> {

private:
std::mt19937_64* prbg_;
std::uint64_t mixer_seq_;
};

} // namespace lab
Expand Down
5 changes: 3 additions & 2 deletions deepmind/engine/lua_random_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ class LuaRandomTest : public lua::testing::TestWithVm {
protected:
LuaRandomTest() {
LuaRandom::Register(L);
vm()->AddCModuleToSearchers("dmlab.system.sys_random",
&lua::Bind<LuaRandom::Require>, {&prbg_});
vm()->AddCModuleToSearchers(
"dmlab.system.sys_random", &lua::Bind<LuaRandom::Require>,
{&prbg_, reinterpret_cast<void*>(static_cast<std::uintptr_t>(0))});
}

std::mt19937_64 prbg_;
Expand Down
9 changes: 6 additions & 3 deletions deepmind/engine/lua_text_level_maker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,11 @@ bool NoOp(std::size_t, std::size_t, char,
LuaTextLevelMaker::LuaTextLevelMaker(
const std::string& self, const std::string& output_folder,
bool use_local_level_cache, bool use_global_level_cache,
DeepMindLabLevelCacheParams level_cache_params)
: prng_(0), rundir_(self), output_folder_(output_folder) {
DeepMindLabLevelCacheParams level_cache_params, std::uint32_t mixer_seed)
: prng_(0),
mixer_seed_(mixer_seed),
rundir_(self),
output_folder_(output_folder) {
settings_.use_local_level_cache = use_local_level_cache;
settings_.use_global_level_cache = use_global_level_cache;
settings_.level_cache_params = level_cache_params;
Expand Down Expand Up @@ -404,7 +407,7 @@ lua::NResultsOr LuaTextLevelMaker::MapFromTextLevel(lua_State* L) {


lua::NResultsOr LuaTextLevelMaker::ViewRandomness(lua_State* L) {
LuaRandom::CreateObject(L, &prng_);
LuaRandom::CreateObject(L, &prng_, mixer_seed_);
return 1;
}

Expand Down
4 changes: 3 additions & 1 deletion deepmind/engine/lua_text_level_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class LuaTextLevelMaker : public lua::Class<LuaTextLevelMaker> {
const std::string& output_folder,
bool use_local_level_cache,
bool use_global_level_cache,
DeepMindLabLevelCacheParams level_cache_params);
DeepMindLabLevelCacheParams level_cache_params,
std::uint32_t mixer_seed);

// Registers MapFromTextLevel as "mapFromTextLevel".
static void Register(lua_State* L);
Expand Down Expand Up @@ -85,6 +86,7 @@ class LuaTextLevelMaker : public lua::Class<LuaTextLevelMaker> {

private:
std::mt19937_64 prng_;
std::uint32_t mixer_seed_;
MapCompileSettings settings_;
const std::string rundir_;
const std::string output_folder_;
Expand Down
4 changes: 4 additions & 0 deletions deepmind/include/deepmind_hooks.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ struct DeepmindHooks_s {
// allowed to set actions.
bool (*get_native_app)(void* userdata);

// Sets the mixer seed. This is a bit toggle sequence applied to the most
// significant bits of the seed.
void (*set_mixer_seed)(void* userdata, int v);

// Sets the actions of the player.
void (*set_actions)(void* userdata, //
double look_down_up, //
Expand Down
6 changes: 4 additions & 2 deletions deepmind/tensor/lua_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ class LuaTensorTest : public ::testing::Test {
LuaTensorTest() : lua_vm_(lua::CreateVm()) {
auto* L = lua_vm_.get();
LuaRandom::Register(L);
lua_vm_.AddCModuleToSearchers("dmlab.system.sys_random",
&lua::Bind<LuaRandom::Require>, {&prbg_});
lua_vm_.AddCModuleToSearchers(
"dmlab.system.sys_random", &lua::Bind<LuaRandom::Require>,
{&prbg_, reinterpret_cast<void*>(static_cast<std::uintptr_t>(0))});
tensor::LuaTensorRegister(L);
lua_vm_.AddCModuleToSearchers("dmlab.system.tensor",
tensor::LuaTensorConstructors);
}
std::mt19937_64 prbg_;
uint32_t mixer_seed_;
lua::Vm lua_vm_;
};

Expand Down
23 changes: 15 additions & 8 deletions docs/users/python_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ name in the list *observations*.
The `config` dict specifies additional settings as key-value string pairs. The
following options are recognized:

| Option | Description | Default |
| ---------------- | ----------------------------------------------- | ------: |
| `width` | horizontal resolution of the observation frames | `'320'` |
| `height` | vertical resolution of the observation frames | `'240'` |
| `fps` | frames per second | `'60'` |
| `levelDirectory` | optional path to level directory (relative | `''` |
: : paths are relative to game_scripts/levels) : :
| `appendCommand` | Commands for the internal Quake console\* | `''` |
| Option | Description | Default |
| ---------------- | ---------------------------------------------------------------------------------------------- | ------: |
| `width` | horizontal resolution of the observation frames | `'320'` |
| `height` | vertical resolution of the observation frames | `'240'` |
| `fps` | frames per second | `'60'` |
| `levelDirectory` | optional path to level directory (relative | `''` |
: : paths are relative to game_scripts/levels) : :
| `appendCommand` | Commands for the internal Quake console\* | `''` |
| `mixerSeed` | value combined with each of the seeds fed to the environment to define unique subsets of seeds | `'0'` |

\* See also [Lua map API](/docs/developers/reference/lua_api.md#commandlineold-commandline-string).

Expand Down Expand Up @@ -88,6 +89,12 @@ The optional integer argument `seed` can be supplied to seed the environment's
random number generator. If `seed` is omitted or `None`, a random number is
used.

The optional integer setting `mixerSeed`, provided in the environment's
`config`, is combined with every seed passed to this function. The resulting
seeds span a disjoint subset of the integers in \[0, 2^64\) for each different
`mixerSeed` value. However, the sequences produced by the environment's random
number generator are not necessarily disjoint.

### `num_steps`()

Number of frames since the last `reset`() call
Expand Down
12 changes: 12 additions & 0 deletions engine/code/deepmind/dmlab_connect.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
////////////////////////////////////////////////////////////////////////////////

#include <errno.h>
#include <inttypes.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
Expand Down Expand Up @@ -636,6 +637,17 @@ static int dmlab_setting(void* context, const char* key, const char* value) {
}
Q_strcat(gc->command_line, sizeof(gc->command_line),
va(" +set name \"%s\"", value));
} else if (strcmp(key, "mixerSeed") == 0) {
int res = parse_int(value, &v, ctx);
if (res != 0) return res;
if (v < 0 || v > UINT32_MAX) {
ctx->hooks.set_error_message(ctx->userdata,
va("Invalid mixerSeed value, must be a "
"positive integer not greater than '%"
PRIu32 "'.", UINT32_MAX));
return 1;
}
ctx->hooks.set_mixer_seed(ctx->userdata, (uint32_t)v);
} else {
ctx->hooks.add_setting(ctx->userdata, key, value);
}
Expand Down
Loading

0 comments on commit d817a04

Please sign in to comment.