From 73faf1814759d0ffb762ef16eda793b29a8433f5 Mon Sep 17 00:00:00 2001 From: Mikel Date: Tue, 11 Jun 2024 11:13:58 +0200 Subject: [PATCH] Initial craftium commit --- .gitignore | 3 + craftium-docs/docs/index.md | 17 +++ craftium-docs/mkdocs.yml | 6 + craftium/__init__.py | 1 + craftium/env.py | 157 ++++++++++++++++++++++++ craftium/minetest.py | 130 ++++++++++++++++++++ craftium/mt_client.py | 49 ++++++++ src/client/client.cpp | 203 ++++++++++++++++++++++++++++++++ src/client/client.h | 9 ++ src/client/game.cpp | 26 +++- src/client/inputhandler.cpp | 20 ++-- src/client/inputhandler.h | 22 +++- src/client/sync.h | 165 ++++++++++++++++++++++++++ src/script/scripting_server.cpp | 6 + src/server.cpp | 4 + test.py | 37 ++++++ wrapper.py | 83 +++++++++++++ 17 files changed, 921 insertions(+), 17 deletions(-) create mode 100644 craftium-docs/docs/index.md create mode 100644 craftium-docs/mkdocs.yml create mode 100644 craftium/__init__.py create mode 100644 craftium/env.py create mode 100644 craftium/minetest.py create mode 100644 craftium/mt_client.py create mode 100644 src/client/sync.h create mode 100644 test.py create mode 100644 wrapper.py diff --git a/.gitignore b/.gitignore index 37e27bfed..d37564977 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,6 @@ lib/irrlichtmt # Generated mod storage database client/mod_storage.sqlite + +## Craftium +__pycache__ \ No newline at end of file diff --git a/craftium-docs/docs/index.md b/craftium-docs/docs/index.md new file mode 100644 index 000000000..8e0604f12 --- /dev/null +++ b/craftium-docs/docs/index.md @@ -0,0 +1,17 @@ +# Craftium + +Craftium is a fully open-source research platform for Reinforcement Learning (RL) research. Craftium provides a [Gymnasium](https://gymnasium.farama.org/index.html) wrapper for the [Minetest](https://www.minetest.net/) voxel game engine. + +## Commands + +* `mkdocs new [dir-name]` - Create a new project. +* `mkdocs serve` - Start the live-reloading docs server. +* `mkdocs build` - Build the documentation site. +* `mkdocs -h` - Print help message and exit. + +## Project layout + + mkdocs.yml # The configuration file. + docs/ + index.md # The documentation homepage. + ... # Other markdown pages, images and other files. diff --git a/craftium-docs/mkdocs.yml b/craftium-docs/mkdocs.yml new file mode 100644 index 000000000..01f706f04 --- /dev/null +++ b/craftium-docs/mkdocs.yml @@ -0,0 +1,6 @@ +site_name: Craftium + +nav: + - Home: index.md + +theme: readthedocs diff --git a/craftium/__init__.py b/craftium/__init__.py new file mode 100644 index 000000000..b2c53d46c --- /dev/null +++ b/craftium/__init__.py @@ -0,0 +1 @@ +from .env import CraftiumEnv diff --git a/craftium/env.py b/craftium/env.py new file mode 100644 index 000000000..e842bd135 --- /dev/null +++ b/craftium/env.py @@ -0,0 +1,157 @@ +import os +from typing import Optional +import time + +from .mt_client import MtClient +from .minetest import Minetest + +import numpy as np + +# import gymnasium as gym +from gymnasium import Env +from gymnasium.spaces import Dict, Discrete, Box + + +class CraftiumEnv(Env): + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} + + def __init__( + self, + obs_width: int = 640, + obs_height: int = 360, + init_frames: int = 15, + render_mode: Optional[str] = None, + max_timesteps: Optional[int] = None, + run_dir: Optional[os.PathLike] = None, + ): + super(CraftiumEnv, self).__init__() + + self.obs_width = obs_width + self.obs_height = obs_height + self.init_frames = init_frames + self.max_timesteps = max_timesteps + + self.action_space = Dict({ + "forward": Discrete(2), + "backward": Discrete(2), + "left": Discrete(2), + "right": Discrete(2), + "jump": Discrete(2), + "aux1": Discrete(2), + "sneak": Discrete(2), + "zoom": Discrete(2), + "dig": Discrete(2), + "place": Discrete(2), + "drop": Discrete(2), + "inventory": Discrete(2), + "slot_1": Discrete(2), + "slot_2": Discrete(2), + "slot_3": Discrete(2), + "slot_4": Discrete(2), + "slot_5": Discrete(2), + "slot_6": Discrete(2), + "slot_7": Discrete(2), + "slot_8": Discrete(2), + "slot_9": Discrete(2), + "mouse": Box(low=-1, high=1, shape=(2,), dtype=np.float32), + }) + + # names of the actions in the order they must be sent to MT + self.action_order = [ + "forward", "backward", "left", "right", "jump", "aux1", "sneak", + "zoom", "dig", "place", "drop", "inventory", "slot_1", "slot_2", + "slot_3", "slot_4", "slot_5", "slot_6", "slot_7", "slot_8", "slot_9", + ] + + self.observation_space = Box(low=0, high=255, shape=(obs_width, obs_height, 3)) + + assert render_mode is None or render_mode in self.metadata["render_modes"] + self.render_mode = render_mode + + # handles the MT configuration and process + self.mt = Minetest( + run_dir=run_dir, + headless=render_mode != "human", + ) + + # variable initialized in the `reset` method + self.client = None # client that connects to minetest + + self.last_observation = None # used in render if "rgb_array" + self.timesteps = 0 # the timesteps counter + + def _get_info(self): + return dict() + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ): + super().reset(seed=seed) + self.timesteps = 0 + + # kill the active mt process and the python client if any + if self.client is not None: + self.client.close() + self.mt.kill_process() + + # start the new MT process + self.mt.start_process() # launch the new MT process + time.sleep(2) # wait for MT to initialize (TODO Improve this) + + # connect the client to the MT process + self.client = MtClient( + img_width=self.obs_width, + img_height=self.obs_height, + ) + + # HACK skip some frames to let the game initialize + for _ in range(self.init_frames): + _observation, _reward = self.client.receive() + self.client.send([0]*21, 0, 0) # nop action + + observation, _reward = self.client.receive() + self.last_observation = observation + + info = self._get_info() + + return observation, info + + def step(self, action): + self.timesteps += 1 + + # convert the action dict to a format to be sent to MT through mt_client + keys = [0]*21 # all commands (keys) except the mouse + mouse_x, mouse_y = 0, 0 + for k, v in action.items(): + if k == "mouse": + x, y = v[0], v[1] + mouse_x = int(x*(self.obs_width // 2)) + mouse_y = int(y*(self.obs_height // 2)) + else: + keys[self.action_order.index(k)] = v + # send the action to MT + self.client.send(keys, mouse_x, mouse_y) + + # receive the new info from minetest + observation, reward = self.client.receive() + self.last_observation = observation + + info = self._get_info() + + # TODO Get the real termination info + terminated = False + truncated = self.max_timesteps is not None and self.timesteps >= self.max_timesteps + + return observation, reward, terminated, truncated, info + + def render(self): + if self.render_mode == "rgb_array": + return self.last_observation + + def close(self): + self.mt.kill_process() + self.mt.clear() + self.client.close() diff --git a/craftium/minetest.py b/craftium/minetest.py new file mode 100644 index 000000000..ab385ad33 --- /dev/null +++ b/craftium/minetest.py @@ -0,0 +1,130 @@ +import os +from typing import Optional, Any +import subprocess +import multiprocessing +from uuid import uuid4 +import shutil +from distutils.dir_util import copy_tree + + +def launch_process(cmd: str, cwd: Optional[os.PathLike] = None): + def launch_fn(): + stderr = open(os.path.join(cwd, "stderr.txt"), "w") + stdout = open(os.path.join(cwd, "stdout.txt"), "w") + subprocess.run(cmd, cwd=cwd, stderr=stderr, stdout=stdout) + process = multiprocessing.Process(target=launch_fn, args=[]) + process.start() + return process + + +class Minetest(): + def __init__( + self, + run_dir: Optional[os.PathLike] = None, + run_dir_prefix: Optional[os.PathLike] = None, + headless: bool = False, + seed: Optional[int] = None, + ): + # create a dedicated directory for this run + if run_dir is None: + self.run_dir = f"./minetest-run-{uuid4()}" + if run_dir_prefix is not None: + self.run_dir = os.path.join(run_dir_prefix, self.run_dir) + else: + self.run_dir = run_dir + # delete the directory if it already exists + if os.path.exists(self.run_dir): + shutil.rmtree(self.run_dir) + # create the directory + os.mkdir(self.run_dir) + + print(f"==> Creating Minetest run directory: {self.run_dir}") + + config = dict( + # Base config + enable_sound=False, + show_debug=False, + enable_client_modding=True, + csm_restriction_flags=0, + enable_mod_channels=True, + screen_w=640, + screen_h=360, + vsync=False, + fps_max=1000, + fps_max_unfocused=1000, + undersampling=1000, + # fov=self.fov_y, + # game_dir=self.game_dir, + + # Adapt HUD size to display size, based on (1024, 600) default + # hud_scaling=self.display_size[0] / 1024, + + # Attempt to improve performance. Impact unclear. + server_map_save_interval=1000000, + profiler_print_interval=0, + active_block_range=2, + abm_time_budget=0.01, + abm_interval=0.1, + active_block_mgmt_interval=4.0, + server_unload_unused_data_timeout=1000000, + client_unload_unused_data_timeout=1000000, + full_block_send_enable_min_time_from_building=0.0, + max_block_send_distance=100, + max_block_generate_distance=100, + num_emerge_threads=0, + emergequeue_limit_total=1000000, + emergequeue_limit_diskonly=1000000, + emergequeue_limit_generate=1000000, + ) + if seed is not None: + config["fixed_map_seed"] = seed + + self._write_config(config, os.path.join(self.run_dir, "minetest.conf")) + + # get the path location of the parent of this module (where all the minetest stuff is located) + root_path = os.path.dirname(os.path.dirname(__file__)) + + # create the directory tree structure needed by minetest + self._create_mt_dirs(root_dir=root_path, target_dir=self.run_dir) + + self.launch_cmd = ["./bin/minetest", "--go"] + + # set the env. variables to execute mintest in headless mode + if headless: + os.environ["SDL_VIDEODRIVER"] = "offscreen" + + self.proc = None + + def start_process(self): + self.proc = launch_process(self.launch_cmd, self.run_dir) + + def kill_process(self): + self.proc.terminate() + + def clear(self): + # delete the run's directory + if os.path.exists(self.run_dir): + shutil.rmtree(self.run_dir) + + def _write_config(self, config: dict[str, Any], path: os.PathLike): + with open(path, "w") as f: + for key, value in config.items(): + f.write(f"{key} = {value}\n") + + def _create_mt_dirs(self, root_dir: os.PathLike, target_dir: os.PathLike): + def link_dir(name): + os.symlink(os.path.join(root_dir, name), + os.path.join(target_dir, name)) + def copy_dir(name): + copy_tree(os.path.join(root_dir, name), + os.path.join(target_dir, name)) + + link_dir("builtin") + link_dir("fonts") + link_dir("locale") + link_dir("textures") + link_dir("bin") + + copy_dir("worlds") + copy_dir("games") + copy_dir("client") diff --git a/craftium/mt_client.py b/craftium/mt_client.py new file mode 100644 index 000000000..870b4b97d --- /dev/null +++ b/craftium/mt_client.py @@ -0,0 +1,49 @@ +import socket +import struct + +import numpy as np + +MT_IP = "127.0.0.1" +MT_PORT = 4343 + +class MtClient(): + def __init__(self, img_width: int, img_height: int): + self.img_width = img_width + self.img_height = img_height + + self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.s.connect((MT_IP, MT_PORT)) + + # pre-compute the number of bytes that we should receive from MT + self.rec_bytes = img_width*img_height*3 + 8 # the RGB image + 8 bytes of the reward + + def receive(self): + data = [] + while len(data) < self.rec_bytes: + data += self.s.recv(self.rec_bytes) + + # decode the reward value + reward_bytes = bytes(data[-8:]) # the last 8 bytes + # uncpack the double (float in python) in native endianess + reward = struct.unpack("d", bytes(reward_bytes))[0] + + # decode the observation RGB image + data = data[:-8] # get the image data, all bytes except the last 8 + # reshape received bytes into an image + img = np.fromiter( + data, + dtype=np.uint8, + count=(self.rec_bytes-8) + ).reshape(self.img_width, self.img_height, 3) + + return img, reward + + def send(self, keys: list[int], mouse_x: int, mouse_y: int): + assert len(keys) == 21, f"Keys list must be of length 21 and is {len(keys)}" + + mouse = list(struct.pack("getU16("server_map_save_interval"); m_mesh_grid = { g_settings->getU16("client_mesh_chunk") }; + + startPyServer(); +} + +void Client::startPyServer() +{ + // Creating socket file descriptor + if ( (pyserv_sockfd = socket(AF_INET, SOCK_STREAM, 0)) < 0 ) { + perror("[ERROR] Obs. server socket creation failed"); + exit(EXIT_FAILURE); + } + + pyserv_servaddr = (struct sockaddr_in*) malloc(sizeof(struct sockaddr_in)); + pyserv_cliaddr = (struct sockaddr_in*) malloc(sizeof(struct sockaddr_in)); + + memset(pyserv_servaddr, 0, sizeof(*pyserv_servaddr)); + memset(pyserv_cliaddr, 0, sizeof(*pyserv_cliaddr)); + + pyserv_servaddr->sin_family = AF_INET; // IPv4 + pyserv_servaddr->sin_addr.s_addr = INADDR_ANY; + pyserv_servaddr->sin_port = htons(pyserv_port); + + // Bind the socket with the server address + if (bind(pyserv_sockfd, + (const struct sockaddr *)pyserv_servaddr, + sizeof(*pyserv_servaddr)) < 0) + { + perror("[ERROR] Obs. server bind failed"); + exit(EXIT_FAILURE); + } + + // Now server is ready to listen and verification + if ((listen(pyserv_sockfd, 5)) != 0) { + printf("[ERROR] Obs. server listen failed...\n"); + exit(EXIT_FAILURE); + } + else + printf("[INFO] Obs. server listening...\n"); + + // Accept the data packet from client and verification + socklen_t len = sizeof(*pyserv_cliaddr); + pyserv_conn = accept(pyserv_sockfd, (struct sockaddr*)pyserv_cliaddr, &len); + if (pyserv_conn < 0) { + printf("[ERROR] Obs. server accept failed...\n"); + exit(EXIT_FAILURE); + } + else + printf("[INFO] Obs. server accepted the client\n"); + + // Set receive and send timeout on the socket + struct timeval timeout; + timeout.tv_sec = 2; /* timeout time in seconds */ + timeout.tv_usec = 0; + if (setsockopt(pyserv_conn, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)) < 0) { + printf("[ERROR] setsockopt failed\n"); + exit(EXIT_FAILURE); + } + if (setsockopt(pyserv_conn, SOL_SOCKET, SO_SNDTIMEO, (const char*)&timeout, sizeof(timeout)) < 0) { + printf("[ERROR] setsockopt failed\n"); + exit(EXIT_FAILURE); + } + + printf("\n[INFO] Obs. server started in port %d\n\n", pyserv_port); +} + +void Client::pyServerListener() { + char actions[25]; + int n_send, n_recv, W, H, obs_rwd_buffer_size; + u32 c; // stores the RGBA pixel color + + /* Take the screenshot */ + irr::video::IVideoDriver *driver = m_rendering_engine->get_video_driver(); + irr::video::IImage* const raw_image = driver->createScreenShot(); + + /* Get the dimensions of the image */ + auto dims = raw_image->getDimension(); + W = dims.Width; + H = dims.Height; + + /* W*H*3 for the WxH RGB image and +8 for the reward value (a double) */ + obs_rwd_buffer_size = W*H*3 + 8; + + /* If obs_rwd_buffer is not initialized, allocate memory for it now */ + if (!obs_rwd_buffer) { + obs_rwd_buffer = (unsigned char*) malloc(obs_rwd_buffer_size); + } + + if (!raw_image) + return; + + /* Copy RGB image into a flat u8 array (obs_rwd_buffer) */ + int i = 0; + for (int w=0; wgetPixel(w, h).color; + obs_rwd_buffer[i] = (c>>16) & 0xff; // R + obs_rwd_buffer[i+1] = (c>>8) & 0xff; // G + obs_rwd_buffer[i+2] = c & 0xff; // B + i = i + 3; + } + } + + /* Encode the reward (double) as 8 bytes at the end of the buffer */ + char *rewardBytes = (char*)&g_reward; + for (int j=0; j<8; j++) { + obs_rwd_buffer[i] = rewardBytes[j]; + i++; + } + + /* Send the obs_rwd_buffer over TCP to Python */ + n_send = send(pyserv_conn, obs_rwd_buffer, obs_rwd_buffer_size, 0); + + /* Receive a buffer of bytes with the actions to take */ + n_recv = recv(pyserv_conn, &actions, sizeof(actions), 0); + + virtual_key_presses[KeyType::FORWARD] = actions[0]; + virtual_key_presses[KeyType::BACKWARD] = actions[1]; + virtual_key_presses[KeyType::LEFT] = actions[2]; + virtual_key_presses[KeyType::RIGHT] = actions[3]; + virtual_key_presses[KeyType::JUMP] = actions[4]; + virtual_key_presses[KeyType::AUX1] = actions[5]; + virtual_key_presses[KeyType::SNEAK] = actions[6]; + virtual_key_presses[KeyType::ZOOM] = actions[7]; + virtual_key_presses[KeyType::DIG] = actions[8]; + virtual_key_presses[KeyType::PLACE] = actions[9]; + virtual_key_presses[KeyType::DROP] = actions[10]; + + /* Handle inventory open/close */ + if (actions[11]) { + if (g_menumgr.m_stack.empty()) { // if no menu is active + virtual_key_presses[KeyType::INVENTORY] = true; // open the inventory + } else { // if the inventory is open + /* Simulate pressing ESC key to close the inventory */ + SEvent ev{}; + ev.EventType = EET_KEY_INPUT_EVENT; + ev.KeyInput.Key = KEY_ESCAPE; + ev.KeyInput.Control = false; + ev.KeyInput.Shift = false; + ev.KeyInput.PressedDown = true; + ev.KeyInput.Char = 0; + + GUIModalMenu *mm = dynamic_cast(g_menumgr.m_stack.back()); + mm->OnEvent(ev); + } + } + + /* Handle mouse events when menu is open */ + if (!g_menumgr.m_stack.empty()) { + SEvent mouse_event{}; + mouse_event.EventType = EET_MOUSE_INPUT_EVENT; + + auto control = RenderingEngine::get_raw_device()->getCursorControl(); + auto pos = control->getPosition(); + + mouse_event.MouseInput.X = pos.X + ((signed char)actions[22] * 256 + (uint8_t)actions[21]); + mouse_event.MouseInput.Y = pos.Y + ((signed char)actions[24] * 256 + (uint8_t)actions[23]); + + /* Action DIG triggers left mouse click (see: wiki.minetest.net/Controls) */ + if (virtual_key_presses[KeyType::DIG]) { + mouse_event.MouseInput.Event = EMIE_LMOUSE_PRESSED_DOWN; + mouse_event.MouseInput.ButtonStates = EMBSM_LEFT; + } else if (virtual_key_presses[KeyType::PLACE]) { + mouse_event.MouseInput.Event = EMIE_RMOUSE_PRESSED_DOWN; + mouse_event.MouseInput.ButtonStates = EMBSM_RIGHT; + } else { + mouse_event.MouseInput.Event = EMIE_MOUSE_MOVED; + } + + GUIModalMenu *mm = dynamic_cast(g_menumgr.m_stack.back()); + mm->preprocessEvent(mouse_event); + } + + /* Hotbar item selection */ + virtual_key_presses[KeyType::SLOT_1] = actions[12]; + virtual_key_presses[KeyType::SLOT_2] = actions[13]; + virtual_key_presses[KeyType::SLOT_3] = actions[14]; + virtual_key_presses[KeyType::SLOT_4] = actions[15]; + virtual_key_presses[KeyType::SLOT_5] = actions[16]; + virtual_key_presses[KeyType::SLOT_6] = actions[17]; + virtual_key_presses[KeyType::SLOT_7] = actions[18]; + virtual_key_presses[KeyType::SLOT_8] = actions[19]; + virtual_key_presses[KeyType::SLOT_9] = actions[20]; + + /* Mouse movement: each position is stored in 2 bytes */ + virtual_mouse_x = (signed char)actions[22] * 256 + (uint8_t)actions[21]; + virtual_mouse_y = (signed char)actions[24] * 256 + (uint8_t)actions[23]; + + /* If sending or receiving went wrong, print an error message and quit */ + if (n_send + n_recv < 2) { + printf("[!!] Python client disconnected. Shutting down...\n"); + exit(43); + } } void Client::migrateModStorage() @@ -408,6 +603,8 @@ void Client::connect(const Address &address, const std::string &address_name, void Client::step(float dtime) { + syncClientStep(); + // Limit a bit if (dtime > DTIME_LIMIT) dtime = DTIME_LIMIT; @@ -420,6 +617,10 @@ void Client::step(float dtime) ReceiveAll(); + /* Clear virtual key presses */ + for (int i=0; istep(dtime); diff --git a/src/client/client.h b/src/client/client.h index b2ff9a0da..316d4b9f9 100644 --- a/src/client/client.h +++ b/src/client/client.h @@ -489,6 +489,15 @@ class Client : public con::PeerHandler, public InventoryManager, public IGameDef MtEventManager *m_event; RenderingEngine *m_rendering_engine; + /* Python API server related */ + int pyserv_port = 4343; + int pyserv_sockfd = 0; + int pyserv_conn = 0; + struct sockaddr_in *pyserv_servaddr = nullptr; + struct sockaddr_in *pyserv_cliaddr = nullptr; + unsigned char *obs_rwd_buffer = 0; + void startPyServer(); + void pyServerListener(); std::unique_ptr m_mesh_update_manager; ClientEnvironment m_env; diff --git a/src/client/game.cpp b/src/client/game.cpp index d35bf8e08..8faa30b86 100644 --- a/src/client/game.cpp +++ b/src/client/game.cpp @@ -19,6 +19,7 @@ with this program; if not, write to the Free Software Foundation, Inc., #include "game.h" +#include #include #include #include "client/renderingengine.h" @@ -84,6 +85,8 @@ with this program; if not, write to the Free Software Foundation, Inc., #include "client/sound/sound_openal.h" #endif +#include "sync.h" + /* Text input system */ @@ -1226,9 +1229,12 @@ void Game::run() updateFrame(&graph, &stats, dtime, cam_view); updateProfilerGraphs(&graph); + /* When running headless we assume that the window is always focused, thus, this block is never run */ + /* if (m_does_lost_focus_pause_game && !device->isWindowFocused() && !isMenuActive()) { showPauseMenu(); } + */ } RenderingEngine::autosaveScreensizeAndCo(initial_screen_size, initial_window_maximized); @@ -1445,7 +1451,7 @@ void Game::copyServerClientCache() { // It would be possible to let the client directly read the media files // from where the server knows they are. But aside from being more complicated - // it would also *not* fill the media cache and cause slower joining of + // it would also *not* fill the media cache and cause slower joining of // remote servers. // (Imagine that you launch a game once locally and then connect to a server.) @@ -2008,7 +2014,9 @@ void Game::updateStats(RunStats *stats, const FpsControl &draw_times, void Game::processUserInput(f32 dtime) { // Reset input if window not active or some menu is active - if (!device->isWindowActive() || isMenuActive() || guienv->hasFocus(gui_chat_console)) { + // bool win_active = device->isWindowActive(); + bool win_active = true; // Assume window is always active when running in headless mode + if (!win_active || isMenuActive() || guienv->hasFocus(gui_chat_console)) { if (m_game_focused) { m_game_focused = false; infostream << "Game lost focus" << std::endl; @@ -2629,8 +2637,11 @@ void Game::updateCameraDirection(CameraOrientation *cam, float dtime) if (cur_control) cur_control->setRelativeMode(!g_touchscreengui && !isMenuActive()); - if ((device->isWindowActive() && device->isWindowFocused() - && !isMenuActive()) || input->isRandom()) { + /* Simulate window is active and focused when running headless */ + // bool window_active = device->isWindowActive() && device->isWindowFocused(); + bool window_active = true; + + if ((window_active && !isMenuActive()) || input->isRandom()) { if (cur_control && !input->isRandom()) { // Mac OSX gets upset if this is set every frame @@ -2764,9 +2775,14 @@ void Game::updatePauseState() inline void Game::step(f32 dtime) { if (server) { + + /* Simulate window is active and focused when running headless */ + float fps_max = g_settings->getFloat("fps_max"); + /* float fps_max = (!device->isWindowFocused() || g_menumgr.pausesGame()) ? g_settings->getFloat("fps_max_unfocused") : g_settings->getFloat("fps_max"); + */ fps_max = std::max(fps_max, 1.0f); /* * Unless you have a barebones game, running the server at more than 60Hz @@ -2785,7 +2801,7 @@ inline void Game::step(f32 dtime) } if (!m_is_paused) - client->step(dtime); + client->step(dtime); } static void pauseNodeAnimation(PausedNodesList &paused, scene::ISceneNode *node) { diff --git a/src/client/inputhandler.cpp b/src/client/inputhandler.cpp index 6dfd2ad35..99ac45b80 100644 --- a/src/client/inputhandler.cpp +++ b/src/client/inputhandler.cpp @@ -24,6 +24,9 @@ with this program; if not, write to the Free Software Foundation, Inc., #include "gui/touchscreengui.h" #include "hud.h" +#include "sync.h" +#include + void KeyCache::populate_nonchanging() { key[KeyType::ESC] = EscapeKey; @@ -192,12 +195,13 @@ bool MyEventReceiver::OnEvent(const SEvent &event) /* * RealInputHandler */ + float RealInputHandler::getMovementSpeed() { - bool f = m_receiver->IsKeyDown(keycache.key[KeyType::FORWARD]), - b = m_receiver->IsKeyDown(keycache.key[KeyType::BACKWARD]), - l = m_receiver->IsKeyDown(keycache.key[KeyType::LEFT]), - r = m_receiver->IsKeyDown(keycache.key[KeyType::RIGHT]); + bool f = m_receiver->IsKeyDown(keycache.key[KeyType::FORWARD]) || virtual_key_presses[KeyType::FORWARD], + b = m_receiver->IsKeyDown(keycache.key[KeyType::BACKWARD]) || virtual_key_presses[KeyType::BACKWARD], + l = m_receiver->IsKeyDown(keycache.key[KeyType::LEFT]) || virtual_key_presses[KeyType::LEFT], + r = m_receiver->IsKeyDown(keycache.key[KeyType::RIGHT]) || virtual_key_presses[KeyType::RIGHT]; if (f || b || l || r) { // if contradictory keys pressed, stay still @@ -219,13 +223,13 @@ float RealInputHandler::getMovementDirection() float x = 0, z = 0; /* Check keyboard for input */ - if (m_receiver->IsKeyDown(keycache.key[KeyType::FORWARD])) + if (m_receiver->IsKeyDown(keycache.key[KeyType::FORWARD]) || virtual_key_presses[KeyType::FORWARD]) z += 1; - if (m_receiver->IsKeyDown(keycache.key[KeyType::BACKWARD])) + if (m_receiver->IsKeyDown(keycache.key[KeyType::BACKWARD]) || virtual_key_presses[KeyType::BACKWARD]) z -= 1; - if (m_receiver->IsKeyDown(keycache.key[KeyType::RIGHT])) + if (m_receiver->IsKeyDown(keycache.key[KeyType::RIGHT]) || virtual_key_presses[KeyType::RIGHT]) x += 1; - if (m_receiver->IsKeyDown(keycache.key[KeyType::LEFT])) + if (m_receiver->IsKeyDown(keycache.key[KeyType::LEFT]) || virtual_key_presses[KeyType::LEFT]) x -= 1; if (x != 0 || z != 0) /* If there is a keyboard event, it takes priority */ diff --git a/src/client/inputhandler.h b/src/client/inputhandler.h index f4fae2b0b..68c67658b 100644 --- a/src/client/inputhandler.h +++ b/src/client/inputhandler.h @@ -25,6 +25,13 @@ with this program; if not, write to the Free Software Foundation, Inc., #include "keycode.h" #include "renderingengine.h" +// For the python API server +#include +#include +#include + +#include "sync.h" + class InputHandler; /**************************************************************************** @@ -285,11 +292,11 @@ class RealInputHandler : public InputHandler virtual bool isKeyDown(GameKeyType k) { - return m_receiver->IsKeyDown(keycache.key[k]) || joystick.isKeyDown(k); + return m_receiver->IsKeyDown(keycache.key[k]) || joystick.isKeyDown(k) || virtual_key_presses[k]; } virtual bool wasKeyDown(GameKeyType k) { - return m_receiver->WasKeyDown(keycache.key[k]) || joystick.wasKeyDown(k); + return m_receiver->WasKeyDown(keycache.key[k]) || joystick.wasKeyDown(k) || virtual_key_presses[k]; } virtual bool wasKeyPressed(GameKeyType k) { @@ -330,8 +337,15 @@ class RealInputHandler : public InputHandler virtual v2s32 getMousePos() { auto control = RenderingEngine::get_raw_device()->getCursorControl(); + + m_mousepos.X += virtual_mouse_x; + m_mousepos.Y += virtual_mouse_y; + if (control) { - return control->getPosition(); + auto pos = control->getPosition(); + pos.X += virtual_mouse_x; + pos.Y += virtual_mouse_y; + return pos; } return m_mousepos; @@ -365,7 +379,7 @@ class RealInputHandler : public InputHandler } private: - MyEventReceiver *m_receiver = nullptr; + MyEventReceiver *m_receiver = nullptr; v2s32 m_mousepos; }; diff --git a/src/client/sync.h b/src/client/sync.h new file mode 100644 index 000000000..f04b94c41 --- /dev/null +++ b/src/client/sync.h @@ -0,0 +1,165 @@ +#pragma once + +#include "keys.h" + +#include +#include +#include +#include +#include +#include +#include + +/* + + "Virtual" keyboard input handling + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +*/ +inline bool virtual_key_presses[KeyType::INTERNAL_ENUM_COUNT]; +inline int virtual_mouse_x = 0; +inline int virtual_mouse_y = 0; + + +/* + + Synchronization between minetest's server and client + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +*/ +#define SYNC_PORT 44445 + +inline int sync_client_fd = -1; +inline int sync_conn_fd = -1; + +inline int syncServerInit() { + int server_fd; + // ssize_t valread; + struct sockaddr_in address; + int opt = 1; + socklen_t addrlen = sizeof(address); + // char buffer[1024] = { 0 }; + // char* hello = "Hello from server"; + + // Creating socket file descriptor + if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + perror("socket failed"); + exit(EXIT_FAILURE); + } + + // Forcefully attaching socket to the port 8080 + if (setsockopt(server_fd, SOL_SOCKET, + SO_REUSEADDR | SO_REUSEPORT, &opt, + sizeof(opt))) { + printf("[SyncServer] setsockopt failed\n"); + exit(EXIT_FAILURE); + } + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(SYNC_PORT); + + // Forcefully attaching socket to the port + if (bind(server_fd, (struct sockaddr*)&address, + sizeof(address)) + < 0) { + printf("[SyncServer] Bind failed\n"); + exit(EXIT_FAILURE); + } + if (listen(server_fd, 3) < 0) { + printf("[SyncServer] Failed to listen\n"); + exit(EXIT_FAILURE); + } + if ((sync_conn_fd + = accept(server_fd, (struct sockaddr*)&address, &addrlen)) < 0) { + printf("[SyncServer] Connection accept error\n"); + exit(EXIT_FAILURE); + } + + // Set receive and send timeout on the socket + struct timeval timeout; + timeout.tv_sec = 2; /* timeout time in seconds */ + timeout.tv_usec = 0; + if (setsockopt(sync_conn_fd, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)) < 0) { + printf("[SyncServer] setsockopt failed\n"); + exit(EXIT_FAILURE); + } + if (setsockopt(sync_conn_fd, SOL_SOCKET, SO_SNDTIMEO, (const char*)&timeout, sizeof(timeout)) < 0) { + printf("[SyncServer] setsockopt failed\n"); + exit(EXIT_FAILURE); + } + + printf("=> Sync Server connected\n"); + + return 0; +} + +inline int syncClientInit() { + int status; + struct sockaddr_in serv_addr; + if ((sync_client_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + printf("[SycClient] Socket creation error\n"); + return -1; + } + + serv_addr.sin_family = AF_INET; + serv_addr.sin_port = htons(SYNC_PORT); + + // Convert IPv4 and IPv6 addresses from text to binary + // form + if (inet_pton(AF_INET, "127.0.0.1", &serv_addr.sin_addr) + <= 0) { + printf("[SycClient] Invalid address\n"); + return -1; + } + + if ((status + = connect(sync_client_fd, (struct sockaddr*)&serv_addr, + sizeof(serv_addr))) + < 0) { + printf("[SycClient] Connection Failed\n"); + return -1; + } + + printf("=> Sync Client connected\n"); + return 0; +} + +inline void syncServerStep() { + if (sync_conn_fd == -1) + syncServerInit(); + + char msg[2]; + read(sync_conn_fd, msg, 2); +} + +inline void syncClientStep() { + if (sync_client_fd == -1) + syncClientInit(); + + /* Send a dummy message of two bytes */ + send(sync_client_fd, "-", 2, 0); +} + +/* + + Reward system + ~~~~~~~~~~~~~ + +*/ +inline double g_reward = 0.0; /* Global variable of the reward value */ + +extern "C" { +#include +} + +/* Implementation of the Lua functions to get/set the global reward value */ +inline static int lua_set_reward(lua_State *L) { + double d = lua_tonumber(L, 1); /* get argument */ + g_reward = d; + return 0; /* number of results */ +} + +inline static int lua_get_reward(lua_State *L) { + lua_pushnumber(L, g_reward); + return 1; /* number of results */ +} diff --git a/src/script/scripting_server.cpp b/src/script/scripting_server.cpp index 324850011..f8a0721b1 100644 --- a/src/script/scripting_server.cpp +++ b/src/script/scripting_server.cpp @@ -51,6 +51,8 @@ extern "C" { #include } +#include "../client/sync.h" + ServerScripting::ServerScripting(Server* server): ScriptApiBase(ScriptingType::Server), asyncEngine(server) @@ -79,6 +81,10 @@ ServerScripting::ServerScripting(Server* server): lua_newtable(L); lua_setfield(L, -2, "luaentities"); + /* Functions to get/set the global reward value */ + lua_register(L, "set_reward", lua_set_reward); + lua_register(L, "get_reward", lua_get_reward); + // Initialize our lua_api modules InitializeModApi(L, top); lua_pop(L, 1); diff --git a/src/server.cpp b/src/server.cpp index 316f349b2..0fac0c89b 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -76,6 +76,8 @@ with this program; if not, write to the Free Software Foundation, Inc., #include "particles.h" #include "gettext.h" +#include "client/sync.h" + class ClientNotFoundException : public BaseException { public: @@ -612,6 +614,8 @@ void Server::step() void Server::AsyncRunStep(float dtime, bool initial_step) { + syncServerStep(); + { // Send blocks to clients SendBlocks(dtime); diff --git a/test.py b/test.py new file mode 100644 index 000000000..06e4d8eb0 --- /dev/null +++ b/test.py @@ -0,0 +1,37 @@ +import time + +from craftium import CraftiumEnv + +import numpy as np +import matplotlib.pyplot as plt + +if __name__ == "__main__": + env = CraftiumEnv( + # render_mode="human", + # max_timesteps=15, + ) + iters = 100 + + observation, info = env.reset() + + start = time.time() + for i in range(iters): + # plt.clf() + # plt.imshow(np.transpose(observation, (1, 0, 2))) + # plt.pause(1e-7) + + # action = env.action_space.sample() + action = dict() + observation, reward, terminated, truncated, _info = env.step(action) + + print(i, reward, terminated, truncated) + + # time.sleep(1) + + if terminated or truncated: + observation, info = env.reset() + + end = time.time() + print(f"** {iters} frames in {end-start}s => {(end-start)/iters} per frame") + + env.close() diff --git a/wrapper.py b/wrapper.py new file mode 100644 index 000000000..7dce227bf --- /dev/null +++ b/wrapper.py @@ -0,0 +1,83 @@ +import subprocess +from threading import Thread +import socket +import time +import numpy as np +import matplotlib.pyplot as plt +import os +import struct + + +def launch_minetest_thread(cmd): + def launch_command(cmd): + subprocess.run(cmd) + + t = Thread(target=launch_command, args=[cmd]) + t.start() + + +def obs_client(port, obs_dim, host="127.0.0.1"): + obs_w, obs_h = obs_dim + obs_bytes = obs_w*obs_h*3 + 8 # the RGB image + 8 bytes of the reward value + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect((host, port)) + + # main interaction loop + frame = 0 + while True: + # receive the whole image's bytes + data = [] + while len(data) < obs_bytes: + data += s.recv(obs_bytes) + + # obtain the reward value + reward_bytes = bytes(data[-8:]) # the last 8 bytes + # uncpack the double (float in python) in native endianess + reward = struct.unpack("d", bytes(reward_bytes))[0] + print(f"Reward: {reward}") + + # obtain the observation RGB image + data = data[:-8] # get the image data, all bytes except the last 8 + # reshape received bytes into an image + img = np.fromiter(data, dtype=np.uint8, count=(obs_bytes-8)).reshape(obs_w, obs_h, 3) + + # plt.clf() + # plt.imshow(np.transpose(img, (1, 0, 2))) + # plt.pause(1e-7) + + # send actions message + # action = list(np.random.randint(2, size=21)) + action = list(np.zeros(21, dtype=np.int8)) + # action[11] = frame == 10 + + # mouse position delta + # x, y = (obs_w // 2) - np.random.randint(obs_w), (obs_h//2) - np.random.randint(obs_h) + # action += list(struct.pack("