Skip to content

Commit

Permalink
feat: add support for file upload
Browse files Browse the repository at this point in the history
Enhances extension compatibility with file upload changes from
vaadin/hilla#3165.
Updates EndpointController to utilize a custom HttpServletRequest,
ensuring seamless integration with Quarkus Multipart Form data handling.
  • Loading branch information
mcollovati committed Feb 1, 2025
1 parent 5e2a573 commit 90d6b6b
Show file tree
Hide file tree
Showing 8 changed files with 396 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2025 Marco Collovati, Dario Götze
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.mcollovati.quarkus.hilla.deployment.asm;

import io.quarkus.gizmo.Gizmo;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TypeInsnNode;

public class EndpointControllerVisitor extends ClassVisitor {

public static final String SPRING_MULTIPART_HTTP_SERVLET_REQUEST =
"org/springframework/web/multipart/MultipartHttpServletRequest";
public static final String QH_MULTIPART_HTTP_SERVLET_REQUEST =
"com/github/mcollovati/quarkus/hilla/multipart/MultipartRequest";

public EndpointControllerVisitor(ClassVisitor classVisitor) {
super(Gizmo.ASM_API_VERSION, classVisitor);
}

@Override
public MethodVisitor visitMethod(
int access, String name, String descriptor, String signature, String[] exceptions) {
MethodVisitor superVisitor = super.visitMethod(access, name, descriptor, signature, exceptions);
if ("doServeEndpoint".equals(name)) {
return new MethodNode(Gizmo.ASM_API_VERSION, access, name, descriptor, signature, exceptions) {
@Override
public void visitEnd() {
var iterator = instructions.iterator();
TypeInsnNode checkCastNode = AsmUtils.findNextInsnNode(
iterator,
node -> node.getOpcode() == Opcodes.CHECKCAST
&& node.desc.equals(SPRING_MULTIPART_HTTP_SERVLET_REQUEST),
TypeInsnNode.class);
checkCastNode.desc = QH_MULTIPART_HTTP_SERVLET_REQUEST;
MethodInsnNode getParameterNode = AsmUtils.findNextInsnNode(
iterator,
node -> node.getOpcode() == Opcodes.INVOKEINTERFACE
&& node.owner.equals(SPRING_MULTIPART_HTTP_SERVLET_REQUEST)
&& node.name.equals("getParameter"),
MethodInsnNode.class);
getParameterNode.setOpcode(Opcodes.INVOKEVIRTUAL);
getParameterNode.owner = QH_MULTIPART_HTTP_SERVLET_REQUEST;
getParameterNode.itf = false;

MethodInsnNode getFileMapNode = AsmUtils.findNextInsnNode(
iterator,
node -> node.getOpcode() == Opcodes.INVOKEINTERFACE
&& node.owner.equals(SPRING_MULTIPART_HTTP_SERVLET_REQUEST)
&& node.name.equals("getFileMap"),
MethodInsnNode.class);
getFileMapNode.setOpcode(Opcodes.INVOKEVIRTUAL);
getFileMapNode.owner = QH_MULTIPART_HTTP_SERVLET_REQUEST;
getFileMapNode.itf = false;
accept(superVisitor);
}
};
}
return superVisitor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ public static void addClassVisitors(BuildProducer<BytecodeTransformerBuildItem>
}));
producer.produce(applicationContextProvider_runOnContext_patch());
producer.produce(endpointCodeGenerator_findBrowserCallables_replacement());
producer.produce(new BytecodeTransformerBuildItem(
"com.vaadin.hilla.EndpointController",
(className, classVisitor) -> new EndpointControllerVisitor(classVisitor)));
}

@SafeVarargs
Expand Down
1 change: 1 addition & 0 deletions commons/hilla-shaded-deps/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
<artifact>org.springframework:spring-web</artifact>
<includes>
<include>org/springframework/http/*.class</include>
<include>org/springframework/web/multipart/MultipartFile.class</include>
</includes>
</filter>
<filter>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,53 +18,57 @@
import jakarta.inject.Inject;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.io.IOException;

import com.fasterxml.jackson.databind.node.ObjectNode;
import com.vaadin.hilla.Endpoint;
import com.vaadin.hilla.EndpointController;
import com.vaadin.hilla.EndpointInvoker;
import com.vaadin.hilla.EndpointRegistry;
import com.vaadin.hilla.auth.CsrfChecker;
import org.springframework.context.ApplicationContext;
import org.jboss.resteasy.reactive.server.multipart.MultipartFormDataInput;
import org.springframework.http.ResponseEntity;

import com.github.mcollovati.quarkus.hilla.multipart.MultipartRequest;

@Path("")
public class QuarkusEndpointController {

static final String ENDPOINT_METHODS = "/{endpoint}/{method}";

private final EndpointController delegate;

/**
* A constructor used to initialize the controller.
*
* @param context Spring context to extract beans annotated with
* {@link Endpoint} from
* @param endpointRegistry the registry used to store endpoint information
* @param endpointInvoker then end point invoker
* @param csrfChecker the csrf checker to use
*/
public QuarkusEndpointController(
ApplicationContext context,
EndpointRegistry endpointRegistry,
EndpointInvoker endpointInvoker,
CsrfChecker csrfChecker) {
delegate = new EndpointController(context, endpointRegistry, endpointInvoker, csrfChecker);
}

@Inject
public QuarkusEndpointController(EndpointController delegate) {
this.delegate = delegate;
QuarkusHillaExtension.markUsed();
}

/**
* Captures and processes the Vaadin endpoint requests.
* <p>
* Matches the endpoint name and a method name with the corresponding Java
* class and a public method in the class. Extracts parameters from a
* request body if the Java method requires any and applies in the same
* order. After the method call, serializes the Java method execution result
* and sends it back.
* <p>
* If an issue occurs during the request processing, an error response is
* returned instead of the serialized Java method return value.
*
* @param endpointName the name of an endpoint to address the calls to, not case
* sensitive
* @param methodName the method name to execute on an endpoint, not case sensitive
* @param body optional request body, that should be specified if the method
* called has parameters
* @param request the current request which triggers the endpoint call
* @param response the current response
* @return execution result as a JSON string or an error message string
*/
@POST
@Path(ENDPOINT_METHODS)
@Produces(MediaType.APPLICATION_JSON)
Expand All @@ -77,6 +81,41 @@ public Response serveEndpoint(

ResponseEntity<String> endpointResponse =
delegate.serveEndpoint(endpointName, methodName, body, request, response);
return buildResponse(endpointResponse);
}

/**
* Captures and processes the Vaadin multipart endpoint requests. They are
* used when there are uploaded files.
* <p>
* This method works as
* {@link #serveEndpoint(String, String, HttpServletRequest, HttpServletResponse, ObjectNode)},
* but it also captures the files uploaded in the request.
*
* @param endpointName the name of an endpoint to address the calls to, not case
* sensitive
* @param methodName the method name to execute on an endpoint, not case sensitive
* @param request the current multipart request which triggers the endpoint call
* @param response the current response
* @return execution result as a JSON string or an error message string
*/
@POST
@Path(ENDPOINT_METHODS)
@Consumes(MediaType.MULTIPART_FORM_DATA)
@Produces(MediaType.APPLICATION_JSON)
public Response serveMultipartEndpoint(
@PathParam("endpoint") String endpointName,
@PathParam("method") String methodName,
@Context HttpServletRequest request,
@Context HttpServletResponse response,
MultipartFormDataInput formData)
throws IOException {
ResponseEntity<String> endpointResponse = delegate.serveMultipartEndpoint(
endpointName, methodName, new MultipartRequest(request, formData), response);
return buildResponse(endpointResponse);
}

private static Response buildResponse(ResponseEntity<String> endpointResponse) {
Response.ResponseBuilder builder =
Response.status(endpointResponse.getStatusCode().value());
endpointResponse.getHeaders().forEach((name, values) -> values.forEach(value -> builder.header(name, value)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,9 @@ EndpointController endpointController(
ApplicationContext context,
EndpointRegistry endpointRegistry,
EndpointInvoker endpointInvoker,
CsrfChecker csrfChecker) {
return new EndpointController(context, endpointRegistry, endpointInvoker, csrfChecker);
CsrfChecker csrfChecker,
@Named("endpointObjectMapper") ObjectMapper objectMapper) {
return new EndpointController(context, endpointRegistry, endpointInvoker, csrfChecker, objectMapper);
}

@Produces
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright 2025 Marco Collovati, Dario Götze
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.mcollovati.quarkus.hilla.multipart;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.jboss.resteasy.reactive.server.multipart.FormValue;
import org.jboss.resteasy.reactive.server.multipart.MultipartFormDataInput;
import org.springframework.http.HttpHeaders;
import org.springframework.web.multipart.MultipartFile;

public class MultipartRequest extends HttpServletRequestWrapper {

private Map<String, Collection<FormValue>> formData;

/**
* Constructs a request object wrapping the given request.
*
* @param request the {@link HttpServletRequest} to be wrapped.
* @throws IllegalArgumentException if the request is null
*/
public MultipartRequest(HttpServletRequest request, MultipartFormDataInput formData) {
super(request);
this.formData = formData.getValues();
}

@Override
public String getParameter(String name) {
Collection<FormValue> values = formData.get(name);
return Optional.ofNullable(values).stream()
.flatMap(Collection::stream)
.filter(fv -> !fv.isFileItem())
.findFirst()
.map(FormValue::getValue)
.orElseGet(() -> super.getParameter(name));
}

public Map<String, MultipartFile> getFileMap() {
return formData.entrySet().stream()
.map(e -> Map.entry(e.getKey(), e.getValue().iterator().next()))
.filter(e -> e.getValue().isFileItem())
.map(e -> new MultipartFileImpl(e.getKey(), e.getValue()))
.collect(Collectors.toMap(MultipartFileImpl::getName, Function.identity()));
}

private static class MultipartFileImpl implements MultipartFile, Serializable {

private final String name;
private final FormValue formValue;

public MultipartFileImpl(String name, FormValue formValue) {
this.name = name;
this.formValue = formValue;
}

@Override
public String getName() {
return name;
}

@Override
public String getOriginalFilename() {
return formValue.getFileName();
}

@Override
public String getContentType() {
return formValue.getHeaders().getFirst(HttpHeaders.CONTENT_TYPE);
}

@Override
public boolean isEmpty() {
return getSize() <= 0L;
}

@Override
public long getSize() {
try {
return formValue.getFileItem().getFileSize();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

@Override
public byte[] getBytes() throws IOException {
return formValue.getFileItem().getInputStream().readAllBytes();
}

@Override
public InputStream getInputStream() throws IOException {
return formValue.getFileItem().getInputStream();
}

@Override
public void transferTo(File dest) throws IOException, IllegalStateException {
transferTo(dest.toPath());
}

@Override
public void transferTo(Path dest) throws IOException, IllegalStateException {
if (dest.toFile().exists()) {
Files.delete(dest);
}
formValue.getFileItem().write(dest);
}
}
}
Loading

0 comments on commit 90d6b6b

Please sign in to comment.