Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Fix render cursor old oopengl #16

Merged
merged 6 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ build-backend = "scikit_build_core.build"

[project]
name = "craftground"
version = "2.5.36"
version = "2.5.37"
description = "Lightweight Minecraft Environment for Reinforcement Learning"
readme = "README.md"
license = { file = "LICENSE" }
Expand Down
1 change: 1 addition & 0 deletions src/craftground/MinecraftEnv/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,4 @@ cmake_install.cmake
CMakeFiles/
.cache
compile_commands.json
_deps/
12 changes: 11 additions & 1 deletion src/craftground/MinecraftEnv/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,14 @@ simulationDistance:6
mkdir build
cd build
cmake ../src/main/cpp
```
```

# Docs on MinecraftEnv.kt
## ResetPhase
1. Initial: END_RESET
2. After reading initial environment: WAIT_INIT_ENDS

## IOPhase
1. Initial: BEGINNING
2. After reading initial environment: GOT_INITIAL_ENVIRONMENT_SHOULD_SEND_OBSERVATION
3.
16 changes: 13 additions & 3 deletions src/craftground/MinecraftEnv/src/main/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,20 @@ project(framebuffer_capturer)
# Set the C++ standard
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED True)
SET(CMAKE_FIND_PACKAGE_SORT_ORDER NATURAL)
SET(CMAKE_FIND_PACKAGE_SORT_DIRECTION DEC)
set(CMAKE_FIND_PACKAGE_SORT_ORDER NATURAL)
set(CMAKE_FIND_PACKAGE_SORT_DIRECTION DEC)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

include(FetchContent)

FetchContent_Declare(
glm
GIT_REPOSITORY https://github.com/g-truc/glm.git
GIT_TAG bf71a834948186f4097caa076cd2663c69a10e1e #refs/tags/1.0.1
)

FetchContent_MakeAvailable(glm)

set(CRAFGROUND_NATIVE_DEBUG $ENV{CRAFGROUND_NATIVE_DEBUG})
if (CRAFGROUND_NATIVE_DEBUG)
message("CRAFGROUND_NATIVE_DEBUG=${CRAFGROUND_NATIVE_DEBUG}")
Expand Down Expand Up @@ -86,7 +96,7 @@ if (CRAFGROUND_NATIVE_DEBUG)
endif()

# Link with JNI and OpenGL libraries
target_link_libraries(native-lib ${JNI_LIBRARIES} ${OPENGL_LIBRARIES} ${PNG_LIBRARIES} ${ZLIB_LIBRARIES})
target_link_libraries(native-lib ${JNI_LIBRARIES} ${OPENGL_LIBRARIES} ${PNG_LIBRARIES} ${ZLIB_LIBRARIES} glm::glm)
if (NOT APPLE)
target_link_libraries(native-lib ${GLEW_LIBRARIES})
endif()
Expand Down
212 changes: 184 additions & 28 deletions src/craftground/MinecraftEnv/src/main/cpp/framebuffer_capturer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#else
// #include <GL/gl.h>
#include <GL/glew.h>
#include <glm/glm.hpp>
#include <glm/gtc/matrix_transform.hpp>
#include <glm/gtc/type_ptr.hpp>
#endif

#include <cstring> // For strcmp
Expand Down Expand Up @@ -167,8 +170,33 @@ const GLubyte cursor[16][16] = {
};

GLuint cursorTexID;
GLuint cursorShaderProgram;
GLuint cursorVAO, cursorVBO, cursorEBO;

float cursorVertices[] = {
// Positions // Texture Coords
0.0f,
0.0f,
0.0f,
0.0f, // Bottom-left
1.0f,
0.0f,
1.0f,
0.0f, // Bottom-right
1.0f,
-1.0f,
1.0f,
1.0f, // Top-right
0.0f,
-1.0f,
0.0f,
1.0f // Top-left
};

// index data
unsigned int cursorIndices[] = {0, 1, 2, 2, 3, 0};

void initCursorTexture() {
bool initCursorTexture() {
glGenTextures(1, &cursorTexID);
glBindTexture(GL_TEXTURE_2D, cursorTexID);

Expand Down Expand Up @@ -219,6 +247,149 @@ void initCursorTexture() {
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);

const char *vertexShaderSource = R"(
#version 330 core
layout(location = 0) in vec2 aPos; // Vertex position
layout(location = 1) in vec2 aTexCoord; // Texture coordinates

out vec2 TexCoord; // Texture coordinates to fragment shader

uniform mat4 projection;
uniform mat4 model;

void main() {
gl_Position = projection * model * vec4(aPos, 0.0, 1.0); // Vertex position
TexCoord = aTexCoord; // Pass the texture
}
)";

const char *fragmentShaderSource = R"(
#version 330 core
out vec4 FragColor;

in vec2 TexCoord; // Texture coordinates from vertex shader
uniform sampler2D uTexture; // Texture sampler

void main() {
FragColor = texture(uTexture, TexCoord); // Output the texture
}
)";
GLuint vertexShader = glCreateShader(GL_VERTEX_SHADER);
glShaderSource(vertexShader, 1, &vertexShaderSource, nullptr);
glCompileShader(vertexShader);

GLint success;
glGetShaderiv(vertexShader, GL_COMPILE_STATUS, &success);
if (!success) {
char infoLog[512];
glGetShaderInfoLog(vertexShader, 512, nullptr, infoLog);
printf("ERROR::SHADER::VERTEX::COMPILATION_FAILED\n%s\n", infoLog);
return false;
}

GLuint fragmentShader = glCreateShader(GL_FRAGMENT_SHADER);
glShaderSource(fragmentShader, 1, &fragmentShaderSource, nullptr);
glCompileShader(fragmentShader);

// Check for shader compile errors
glGetShaderiv(fragmentShader, GL_COMPILE_STATUS, &success);
if (!success) {
char infoLog[512];
glGetShaderInfoLog(fragmentShader, 512, nullptr, infoLog);
printf("ERROR::SHADER::FRAGMENT::COMPILATION_FAILED\n%s\n", infoLog);
return false;
}

cursorShaderProgram = glCreateProgram();
glAttachShader(cursorShaderProgram, vertexShader);
glAttachShader(cursorShaderProgram, fragmentShader);
glLinkProgram(cursorShaderProgram);

glGetProgramiv(cursorShaderProgram, GL_LINK_STATUS, &success);
if (!success) {
char infoLog[512];
glGetProgramInfoLog(cursorShaderProgram, 512, nullptr, infoLog);
printf("ERROR::SHADER::PROGRAM::LINKING_FAILED\n%s\n", infoLog);
return false;
}

// remove shaders (no longer needed after linking)
glDeleteShader(vertexShader);
glDeleteShader(fragmentShader);

glGenVertexArrays(1, &cursorVAO);
glGenBuffers(1, &cursorVBO);
glGenBuffers(1, &cursorEBO);

glBindVertexArray(cursorVAO);

glBindBuffer(GL_ARRAY_BUFFER, cursorVBO);
glBufferData(
GL_ARRAY_BUFFER, sizeof(cursorVertices), cursorVertices, GL_STATIC_DRAW
);

glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, cursorEBO);
glBufferData(
GL_ELEMENT_ARRAY_BUFFER,
sizeof(cursorIndices),
cursorIndices,
GL_STATIC_DRAW
);

// Position attribute (aPos)
glVertexAttribPointer(
0, 2, GL_FLOAT, GL_FALSE, 4 * sizeof(float), (void *)0
);
glEnableVertexAttribArray(0);

// Texture attribute (aTexCoord)
glVertexAttribPointer(
1, 2, GL_FLOAT, GL_FALSE, 4 * sizeof(float), (void *)(2 * sizeof(float))
);
glEnableVertexAttribArray(1);

glBindVertexArray(0); // Unbind VAO

return true;
}

// TODO: USE shader
/*
void renderCursor(jint mouseX, jint mouseY) {
glBindTexture(GL_TEXTURE_2D, cursorTexID);
glEnable(GL_TEXTURE_2D);
glBegin(GL_QUADS);
glTexCoord2f(0.0f, 0.0f);
glVertex2f(mouseX, mouseY);
glTexCoord2f(1.0f, 0.0f);
glVertex2f(mouseX + 16, mouseY);
glTexCoord2f(1.0f, 1.0f);
glVertex2f(mouseX + 16, mouseY - 16);
glTexCoord2f(0.0f, 1.0f);
glVertex2f(mouseX, mouseY - 16);
glEnd();
glDisable(GL_TEXTURE_2D);
}
*/

void renderCursor(jint mouseX, jint mouseY) {
glUseProgram(cursorShaderProgram);
glActiveTexture(GL_TEXTURE0);
glBindTexture(GL_TEXTURE_2D, cursorTexID);
glUniform1i(glGetUniformLocation(cursorShaderProgram, "uTexture"), 0);
glm::mat4 model = glm::mat4(1.0f);
model = glm::translate(model, glm::vec3(mouseX, mouseY, 0.0f));
model = glm::scale(model, glm::vec3(16.0f, 16.0f, 1.0f));
glUniformMatrix4fv(
glGetUniformLocation(cursorShaderProgram, "model"),
1,
GL_FALSE,
glm::value_ptr(model)
);
glBindVertexArray(cursorVAO);
glDrawElements(GL_TRIANGLES, 6, GL_UNSIGNED_INT, 0);
glBindVertexArray(0);
}

extern "C" JNIEXPORT jobject JNICALL
Expand Down Expand Up @@ -378,6 +549,11 @@ Java_com_kyhsgeekcode_minecraftenv_FramebufferCapturer_initializeZerocopyImpl(
jint depthAttachment,
jint python_pid
) {
if (!initCursorTexture()) {
fflush(stderr);
fflush(stdout);
return nullptr;
}
jclass byteStringClass = env->FindClass("com/google/protobuf/ByteString");
if (byteStringClass == nullptr || env->ExceptionCheck()) {
return nullptr;
Expand Down Expand Up @@ -432,20 +608,7 @@ Java_com_kyhsgeekcode_minecraftenv_FramebufferCapturer_captureFramebufferZerocop
) {
glBindFramebuffer(GL_READ_FRAMEBUFFER, frameBufferId);
if (drawCursor) {
glBindTexture(GL_TEXTURE_2D, cursorTexID);

glEnable(GL_TEXTURE_2D);
glBegin(GL_QUADS);
glTexCoord2f(0.0f, 0.0f);
glVertex2f(mouseX, mouseY);
glTexCoord2f(1.0f, 0.0f);
glVertex2f(mouseX + 16, mouseY);
glTexCoord2f(1.0f, 1.0f);
glVertex2f(mouseX + 16, mouseY - 16);
glTexCoord2f(0.0f, 1.0f);
glVertex2f(mouseX, mouseY - 16);
glEnd();
glDisable(GL_TEXTURE_2D);
renderCursor(mouseX, mouseY);
}

// It could have been that the rendered image is already being shared,
Expand All @@ -468,6 +631,11 @@ Java_com_kyhsgeekcode_minecraftenv_FramebufferCapturer_initializeZerocopyImpl(
jint depthAttachment,
jint python_pid
) {
if (!initCursorTexture()) {
fflush(stderr);
fflush(stdout);
return nullptr;
}
jclass runtimeExceptionClass = env->FindClass("java/lang/RuntimeException");
if (runtimeExceptionClass == nullptr) {
fprintf(stderr, "Failed to find RuntimeException class\n");
Expand Down Expand Up @@ -548,19 +716,7 @@ Java_com_kyhsgeekcode_minecraftenv_FramebufferCapturer_captureFramebufferZerocop
glBindFramebuffer(GL_READ_FRAMEBUFFER, frameBufferId);

if (drawCursor) {
glBindTexture(GL_TEXTURE_2D, cursorTexID);
glEnable(GL_TEXTURE_2D);
glBegin(GL_QUADS);
glTexCoord2f(0.0f, 0.0f);
glVertex2f(mouseX, mouseY);
glTexCoord2f(1.0f, 0.0f);
glVertex2f(mouseX + 16, mouseY);
glTexCoord2f(1.0f, 1.0f);
glVertex2f(mouseX + 16, mouseY - 16);
glTexCoord2f(0.0f, 1.0f);
glVertex2f(mouseX, mouseY - 16);
glEnd();
glDisable(GL_TEXTURE_2D);
renderCursor(mouseX, mouseY);
}

// CUDA IPC handles are used to share the framebuffer with the Python side
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ int initialize_cuda_ipc(
return sizeof(cudaIpcMemHandle_t);
}

void checkAndPrintGLError() {
GLenum error = glGetError();
while (error != GL_NO_ERROR) {
printf("OpenGL Error: 0x%x\n", error);
error = glGetError();
}
fflush(stdout);
}

void copyFramebufferToCudaSharedMemory(int width, int height) {
GLuint renderedTextureId;
glGetFramebufferAttachmentParameteriv(
Expand All @@ -89,29 +98,37 @@ void copyFramebufferToCudaSharedMemory(int width, int height) {
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
(GLint *)&renderedTextureId
);
checkAndPrintGLError();
glBindTexture(GL_TEXTURE_2D, renderedTextureId);
checkAndPrintGLError();
int textureWidth, textureHeight;
int format;
glGetTexLevelParameteriv(GL_TEXTURE_2D, 0, GL_TEXTURE_WIDTH, &textureWidth);
checkAndPrintGLError();
glGetTexLevelParameteriv(
GL_TEXTURE_2D, 0, GL_TEXTURE_HEIGHT, &textureHeight
);
checkAndPrintGLError();
glGetTexLevelParameteriv(
GL_TEXTURE_2D, 0, GL_TEXTURE_INTERNAL_FORMAT, &format
);
checkAndPrintGLError();
// printf("width: %d, height: %d, format: %d\n", textureWidth,
// textureHeight, format); fflush(stdout);
assert(format == GL_RGBA8);
// printf("width: %d, height: %d\n", textureWidth, textureHeight);
glViewport(0, 0, width, height);
checkAndPrintGLError();
glReadBuffer(GL_COLOR_ATTACHMENT0);
checkAndPrintGLError();
GLenum status = glCheckFramebufferStatus(GL_READ_FRAMEBUFFER);
checkAndPrintGLError();
if (status != GL_FRAMEBUFFER_COMPLETE) {
printf("Framebuffer is not complete! Status: 0x%x\n", status);
fflush(stdout);
assert(status == GL_FRAMEBUFFER_COMPLETE);
}
assert(glGetError() == GL_NO_ERROR);
fflush(stdout);
assert(width == textureWidth);
assert(height == textureHeight);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ class EnvironmentInitializer(
private val initialEnvironment: InitialEnvironmentMessage,
private val csvLogger: CsvLogger,
) {
var hasRunInitWorld: Boolean = false
private var hasRunInitWorld: Boolean = false
private set
var initWorldFinished: Boolean = false
private set

private lateinit var minecraftServer: MinecraftServer
private lateinit var player: ClientPlayerEntity
var hasMinimizedWindow: Boolean = false
private var hasMinimizedWindow: Boolean = false

private var initializedClient = false
private var finishedEnteringWorld = false
Expand Down
Loading
Loading