Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract interface for Runtime API client #471

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//
// AWSLambda.java
//
// Copyright (c) 2013 Amazon. All rights reserved.
//
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/
package com.amazonaws.services.lambda.runtime.api.client;

import com.amazonaws.services.lambda.crac.Core;
Expand All @@ -12,23 +11,25 @@
import com.amazonaws.services.lambda.runtime.api.client.logging.LambdaContextLogger;
import com.amazonaws.services.lambda.runtime.api.client.logging.LogSink;
import com.amazonaws.services.lambda.runtime.api.client.logging.StdOutLogSink;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.InvocationRequest;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.LambdaRuntimeClient;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.LambdaRuntimeApiClient;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.LambdaRuntimeApiClientImpl;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.converters.LambdaErrorConverter;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.converters.XRayErrorCauseConverter;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.InvocationRequest;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.LambdaError;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.XRayErrorCause;
import com.amazonaws.services.lambda.runtime.api.client.util.LambdaOutputStream;
import com.amazonaws.services.lambda.runtime.api.client.util.UnsafeUtil;
import com.amazonaws.services.lambda.runtime.logging.LogFormat;
import com.amazonaws.services.lambda.runtime.logging.LogLevel;
import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer;
import com.amazonaws.services.lambda.runtime.serialization.factories.GsonFactory;
import com.amazonaws.services.lambda.runtime.serialization.factories.JacksonFactory;
import com.amazonaws.services.lambda.runtime.serialization.util.ReflectUtil;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileDescriptor;
import java.io.FileInputStream;
import java.io.IOError;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.lang.reflect.Constructor;
import java.net.URLClassLoader;
Expand Down Expand Up @@ -67,6 +68,10 @@ public class AWSLambda {

private static final String AWS_LAMBDA_INITIALIZATION_TYPE = System.getenv(ReservedRuntimeEnvironmentVariables.AWS_LAMBDA_INITIALIZATION_TYPE);

protected static URLClassLoader customerClassLoader;

private static LambdaRuntimeApiClient runtimeClient;

static {
// Override the disabledAlgorithms setting to match configuration for openjdk8-u181.
// This is to keep DES ciphers around while we deploying security updates.
Expand Down Expand Up @@ -143,17 +148,6 @@ public static void setupRuntimeLogger(LambdaLogger lambdaLogger)
);
}

public static String getEnvOrExit(String envVariableName) {
String value = System.getenv(envVariableName);
if (value == null) {
System.err.println("Could not get environment variable " + envVariableName);
System.exit(-1);
}
return value;
}

protected static URLClassLoader customerClassLoader;

/**
* convert an integer into a FileDescriptor object using reflection to access private members.
*/
Expand Down Expand Up @@ -207,8 +201,7 @@ private static void startRuntime(String handler, LambdaContextLogger lambdaLogge
System.setErr(new PrintStream(new LambdaOutputStream(System.err), false, "UTF-8"));
setupRuntimeLogger(lambdaLogger);

String runtimeApi = getEnvOrExit(ReservedRuntimeEnvironmentVariables.AWS_LAMBDA_RUNTIME_API);
LambdaRuntimeClient runtimeClient = new LambdaRuntimeClient(runtimeApi);
runtimeClient = new LambdaRuntimeApiClientImpl(LambdaEnvironment.RUNTIME_API);

String taskRoot = System.getProperty("user.dir");
String libRoot = "/opt/java";
Expand All @@ -223,17 +216,18 @@ private static void startRuntime(String handler, LambdaContextLogger lambdaLogge
requestHandler = findRequestHandler(handler, customerClassLoader);
} catch (UserFault userFault) {
lambdaLogger.log(userFault.reportableError(), lambdaLogger.getLogFormat() == LogFormat.JSON ? LogLevel.ERROR : LogLevel.UNDEFINED);
reportInitError(new Failure(userFault), runtimeClient);
LambdaError error = LambdaErrorConverter.fromUserFault(userFault);
runtimeClient.reportInitError(error);
System.exit(1);
return;
}
if (INIT_TYPE_SNAP_START.equals(AWS_LAMBDA_INITIALIZATION_TYPE)) {
onInitComplete(runtimeClient, lambdaLogger);
onInitComplete(lambdaLogger);
}
boolean shouldExit = false;
while (!shouldExit) {
UserFault userFault = null;
InvocationRequest request = runtimeClient.waitForNextInvocation();
InvocationRequest request = runtimeClient.nextInvocation();
if (request.getXrayTraceId() != null) {
System.setProperty(LAMBDA_TRACE_HEADER_PROP, request.getXrayTraceId());
} else {
Expand All @@ -243,26 +237,23 @@ private static void startRuntime(String handler, LambdaContextLogger lambdaLogge
ByteArrayOutputStream payload;
try {
payload = requestHandler.call(request);
runtimeClient.postInvocationResponse(request.getId(), payload.toByteArray());
runtimeClient.reportInvocationSuccess(request.getId(), payload.toByteArray());
boolean ignored = Thread.interrupted(); // clear interrupted flag in case if it was set by user's code
} catch (UserFault f) {
shouldExit = f.fatal;
userFault = f;
UserFault.filterStackTrace(f);
payload = new ByteArrayOutputStream(1024);
Failure failure = new Failure(f);
GsonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
shouldExit = f.fatal;
runtimeClient.postInvocationError(request.getId(), payload.toByteArray(), failure.getErrorType());

LambdaError error = LambdaErrorConverter.fromUserFault(f);
runtimeClient.reportInvocationError(request.getId(), error);
} catch (Throwable t) {
shouldExit = t instanceof VirtualMachineError || t instanceof IOError;
UserFault.filterStackTrace(t);
userFault = UserFault.makeUserFault(t);
payload = new ByteArrayOutputStream(1024);
Failure failure = new Failure(t);
GsonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
// These two categories of errors are considered fatal.
shouldExit = Failure.isInvokeFailureFatal(t);
runtimeClient.postInvocationError(request.getId(), payload.toByteArray(), failure.getErrorType(),
serializeAsXRayJson(t));

LambdaError error = LambdaErrorConverter.fromThrowable(t);
XRayErrorCause xRayErrorCause = XRayErrorCauseConverter.fromThrowable(t);
runtimeClient.reportInvocationError(request.getId(), error, xRayErrorCause);
} finally {
if (userFault != null) {
lambdaLogger.log(userFault.reportableError(), lambdaLogger.getLogFormat() == LogFormat.JSON ? LogLevel.ERROR : LogLevel.UNDEFINED);
Expand All @@ -271,23 +262,22 @@ private static void startRuntime(String handler, LambdaContextLogger lambdaLogge
}
}

static void onInitComplete(final LambdaRuntimeClient runtimeClient, final LambdaContextLogger lambdaLogger) throws IOException {
static void onInitComplete(final LambdaContextLogger lambdaLogger) throws IOException {
try {
Core.getGlobalContext().beforeCheckpoint(null);
// Blocking call to RAPID /restore/next API, will return after taking snapshot.
// This will also be the 'entrypoint' when resuming from snapshots.
runtimeClient.getRestoreNext();
runtimeClient.restoreNext();
} catch (Exception e1) {
logExceptionCloudWatch(lambdaLogger, e1);
reportInitError(new Failure(e1), runtimeClient);
LambdaError error = LambdaErrorConverter.fromThrowable(e1);
runtimeClient.reportInitError(error);
System.exit(64);
}
try {
Core.getGlobalContext().afterRestore(null);
} catch (Exception restoreExc) {
logExceptionCloudWatch(lambdaLogger, restoreExc);
Failure errorPayload = new Failure(restoreExc);
reportRestoreError(errorPayload, runtimeClient);
LambdaError error = LambdaErrorConverter.fromThrowable(restoreExc);
runtimeClient.reportRestoreError(error);
System.exit(64);
}
}
Expand All @@ -297,40 +287,4 @@ private static void logExceptionCloudWatch(LambdaContextLogger lambdaLogger, Exc
UserFault userFault = UserFault.makeUserFault(exc, true);
lambdaLogger.log(userFault.reportableError(), lambdaLogger.getLogFormat() == LogFormat.JSON ? LogLevel.ERROR : LogLevel.UNDEFINED);
}

static void reportInitError(final Failure failure,
final LambdaRuntimeClient runtimeClient) throws IOException {

ByteArrayOutputStream payload = new ByteArrayOutputStream(1024);
JacksonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
runtimeClient.postInitError(payload.toByteArray(), failure.getErrorType());
}

static int reportRestoreError(final Failure failure,
final LambdaRuntimeClient runtimeClient) throws IOException {

ByteArrayOutputStream payload = new ByteArrayOutputStream(1024);
JacksonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
return runtimeClient.postRestoreError(payload.toByteArray(), failure.getErrorType());
}

private static PojoSerializer<XRayErrorCause> xRayErrorCauseSerializer;

/**
* @param throwable throwable to convert
* @return json as string expected by XRay's web console. On conversion failure, returns null.
*/
private static String serializeAsXRayJson(Throwable throwable) {
try {
final OutputStream outputStream = new ByteArrayOutputStream();
final XRayErrorCause cause = new XRayErrorCause(throwable);
if (xRayErrorCauseSerializer == null) {
xRayErrorCauseSerializer = JacksonFactory.getInstance().getSerializer(XRayErrorCause.class);
}
xRayErrorCauseSerializer.toJson(cause, outputStream);
return outputStream.toString();
} catch (Exception e) {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ private static String pathToClassName(final String path) {
private static void loadClass(String name) {
try {
Class.forName(name, true, SYSTEM_CLASS_LOADER);
System.out.println("Loaded " + name);
} catch (ClassNotFoundException e) {
System.err.println("[WARN] Failed to load " + name + ": " + e.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import com.amazonaws.services.lambda.runtime.api.client.api.LambdaCognitoIdentity;
import com.amazonaws.services.lambda.runtime.api.client.api.LambdaContext;
import com.amazonaws.services.lambda.runtime.api.client.logging.LambdaContextLogger;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.InvocationRequest;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.InvocationRequest;
import com.amazonaws.services.lambda.runtime.api.client.util.UnsafeUtil;
import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer;
import com.amazonaws.services.lambda.runtime.serialization.events.LambdaEventSerializers;
Expand All @@ -22,6 +22,7 @@
import com.amazonaws.services.lambda.runtime.serialization.util.Functions;
import com.amazonaws.services.lambda.runtime.serialization.util.ReflectUtil;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -902,7 +903,8 @@ public ByteArrayOutputStream call(InvocationRequest request) throws Error, Excep
}
}

handler.handleRequest(request.getContentAsStream(), output, context);
ByteArrayInputStream bais = new ByteArrayInputStream(request.getContent());
handler.handleRequest(bais, output, context);
return output;
}
};
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ public class LambdaEnvironment {
public static final String LAMBDA_LOG_FORMAT = ENV_READER.getEnvOrDefault(AWS_LAMBDA_LOG_FORMAT, "TEXT");
public static final String FUNCTION_NAME = ENV_READER.getEnv(AWS_LAMBDA_FUNCTION_NAME);
public static final String FUNCTION_VERSION = ENV_READER.getEnv(AWS_LAMBDA_FUNCTION_VERSION);
public static final String RUNTIME_API = ENV_READER.getEnv(AWS_LAMBDA_RUNTIME_API);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

package com.amazonaws.services.lambda.runtime.api.client;

import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.InvocationRequest;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.InvocationRequest;

import java.io.ByteArrayOutputStream;

Expand Down
Loading
Loading