Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure caches are not used unsafely #10691

Merged
merged 8 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .mvn/modernizer/violations.xml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@
<comment>Prefer Math.toIntExact(long)</comment>
</violation>

<violation>
<name>com/google/common/cache/CacheBuilder.build:()Lcom/google/common/cache/Cache;</name>
<version>1.8</version>
<comment>Guava Cache has concurrency issues around invalidation and ongoing loads. Use EvictableCache, EvictableLoadingCache, or SafeCaches to build caches.
See https://github.com/trinodb/trino/issues/10512 for more information and see https://github.com/trinodb/trino/issues/10512#issuecomment-1016221168
for why Caffeine does not solve the problem.</comment>
</violation>

<violation>
<name>com/google/common/cache/CacheBuilder.build:(Lcom/google/common/cache/CacheLoader;)Lcom/google/common/cache/LoadingCache;</name>
<version>1.8</version>
<comment>Guava LoadingCache has concurrency issues around invalidation and ongoing loads. Use EvictableCache, EvictableLoadingCache, or SafeCaches to build caches.
See https://github.com/trinodb/trino/issues/10512 for more information and see https://github.com/trinodb/trino/issues/10512#issuecomment-1016221168
for why Caffeine does not solve the problem.</comment>
</violation>

<violation>
<name>org/testng/Assert.assertEquals:(Ljava/lang/Iterable;Ljava/lang/Iterable;)V</name>
<version>1.8</version>
Expand Down
5 changes: 5 additions & 0 deletions client/trino-cli/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@
<artifactId>antlr4-runtime</artifactId>
</dependency>

<dependency>
<groupId>org.gaul</groupId>
<artifactId>modernizer-maven-annotations</artifactId>
</dependency>

<dependency>
<groupId>org.jline</groupId>
<artifactId>jline-reader</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.trino.client.QueryData;
import io.trino.client.StatementClient;
import org.gaul.modernizer_maven_annotations.SuppressModernizer;
import org.jline.reader.Candidate;
import org.jline.reader.Completer;
import org.jline.reader.LineReader;
Expand Down Expand Up @@ -51,12 +52,21 @@ public TableNameCompleter(QueryRunner queryRunner)
{
this.queryRunner = requireNonNull(queryRunner, "queryRunner session was null!");

tableCache = CacheBuilder.newBuilder()
.refreshAfterWrite(RELOAD_TIME_MINUTES, TimeUnit.MINUTES)
.build(asyncReloading(CacheLoader.from(this::listTables), executor));
tableCache = buildUnsafeCache(
CacheBuilder.newBuilder()
.refreshAfterWrite(RELOAD_TIME_MINUTES, TimeUnit.MINUTES),
asyncReloading(CacheLoader.from(this::listTables), executor));

functionCache = CacheBuilder.newBuilder()
.build(asyncReloading(CacheLoader.from(this::listFunctions), executor));
functionCache = buildUnsafeCache(
CacheBuilder.newBuilder(),
CacheLoader.from(this::listFunctions));
}

// TODO extract safe caches implementations to a new module and use SafeCaches.buildNonEvictableCache here
// Builds a plain Guava LoadingCache from the given builder and loader.
// @SuppressModernizer silences the Modernizer check for the direct CacheBuilder.build(CacheLoader) call below.
@SuppressModernizer
private static <K, V> LoadingCache<K, V> buildUnsafeCache(CacheBuilder<? super K, ? super V> cacheBuilder, CacheLoader<? super K, V> cacheLoader)
{
return cacheBuilder.build(cacheLoader);
}

private List<String> listTables(String schemaName)
Expand Down
5 changes: 5 additions & 0 deletions core/trino-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@
<version>8.4.1</version>
</dependency>

<dependency>
<groupId>org.gaul</groupId>
<artifactId>modernizer-maven-annotations</artifactId>
</dependency>

<dependency>
<groupId>org.jgrapht</groupId>
<artifactId>jgrapht-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
*/
package io.trino.execution;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.airlift.units.Duration;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.spi.ErrorCode;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.ErrorType;
Expand All @@ -29,6 +29,7 @@
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.execution.FailureInjector.InjectedFailureType.TASK_FAILURE;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.spi.ErrorType.EXTERNAL;
import static io.trino.spi.ErrorType.INSUFFICIENT_RESOURCES;
import static io.trino.spi.ErrorType.INTERNAL_ERROR;
Expand All @@ -40,7 +41,7 @@ public class FailureInjector
{
public static final String FAILURE_INJECTION_MESSAGE = "This error is injected by the failure injection service";

private final Cache<Key, InjectedFailure> failures;
private final NonEvictableCache<Key, InjectedFailure> failures;
private final Duration requestTimeout;

@Inject
Expand All @@ -53,9 +54,8 @@ public FailureInjector(FailureInjectionConfig config)

public FailureInjector(Duration expirationPeriod, Duration requestTimeout)
{
failures = CacheBuilder.newBuilder()
.expireAfterWrite(expirationPeriod.toMillis(), MILLISECONDS)
.build();
failures = buildNonEvictableCache(CacheBuilder.newBuilder()
.expireAfterWrite(expirationPeriod.toMillis(), MILLISECONDS));
this.requestTimeout = requireNonNull(requestTimeout, "requestTimeout is null");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.concurrent.ThreadPoolExecutorMBean;
Expand All @@ -40,6 +39,7 @@
import io.trino.memory.MemoryPoolAssignmentsRequest;
import io.trino.memory.NodeMemoryConfig;
import io.trino.memory.QueryContext;
import io.trino.plugin.base.cache.NonEvictableLoadingCache;
import io.trino.spi.QueryId;
import io.trino.spi.TrinoException;
import io.trino.spi.VersionEmbedder;
Expand Down Expand Up @@ -80,6 +80,7 @@
import static io.trino.execution.SqlTask.createSqlTask;
import static io.trino.memory.LocalMemoryManager.GENERAL_POOL;
import static io.trino.memory.LocalMemoryManager.RESERVED_POOL;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.spi.StandardErrorCode.ABANDONED_TASK;
import static io.trino.spi.StandardErrorCode.SERVER_SHUTTING_DOWN;
import static java.lang.Math.min;
Expand All @@ -104,8 +105,8 @@ public class SqlTaskManager
private final Duration clientTimeout;

private final LocalMemoryManager localMemoryManager;
private final LoadingCache<QueryId, QueryContext> queryContexts;
private final LoadingCache<TaskId, SqlTask> tasks;
private final NonEvictableLoadingCache<QueryId, QueryContext> queryContexts;
private final NonEvictableLoadingCache<TaskId, SqlTask> tasks;

private final SqlTaskIoStats cachedStats = new SqlTaskIoStats();
private final SqlTaskIoStats finishedTaskStats = new SqlTaskIoStats();
Expand Down Expand Up @@ -163,10 +164,10 @@ public SqlTaskManager(
queryMaxMemoryPerNode = maxQueryMemoryPerNode.toBytes();
queryMaxTotalMemoryPerNode = maxQueryTotalMemoryPerNode.toBytes();

queryContexts = CacheBuilder.newBuilder().weakValues().build(CacheLoader.from(
queryContexts = buildNonEvictableCache(CacheBuilder.newBuilder().weakValues(), CacheLoader.from(
queryId -> createQueryContext(queryId, localMemoryManager, localSpillManager, gcMonitor, maxQueryMemoryPerNode, maxQueryTotalMemoryPerNode, queryMaxMemoryPerTask, maxQuerySpillPerNode)));

tasks = CacheBuilder.newBuilder().build(CacheLoader.from(
tasks = buildNonEvictableCache(CacheBuilder.newBuilder(), CacheLoader.from(
taskId -> createSqlTask(
taskId,
locationFactory.createLocalTaskLocation(taskId),
Expand Down Expand Up @@ -485,7 +486,8 @@ public TaskInfo failTask(TaskId taskId, Throwable failure)
return tasks.getUnchecked(taskId).failed(failure);
}

public void removeOldTasks()
@VisibleForTesting
void removeOldTasks()
{
DateTime oldestAllowedTask = DateTime.now().minus(infoCacheTime.toMillis());
tasks.asMap().values().stream()
Expand All @@ -496,6 +498,8 @@ public void removeOldTasks()
try {
DateTime endTime = taskInfo.getStats().getEndTime();
if (endTime != null && endTime.isBefore(oldestAllowedTask)) {
// The removal here is concurrency safe with respect to any concurrent loads: the cache has no expiration,
// the taskId is in the cache, so there mustn't be an ongoing load.
tasks.asMap().remove(taskId);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to assert that the task was actually removed, to match what the comment above says?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, but I don't feel comfortable adding assertions in this class. Do you want to address this as a follow-up?

}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino.execution.scheduler;

import com.google.common.base.Suppliers;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
Expand All @@ -27,6 +26,7 @@
import io.trino.execution.NodeTaskMap;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.spi.HostAddress;
import io.trino.spi.SplitWeight;

Expand All @@ -38,23 +38,25 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask;
import static io.trino.metadata.NodeState.ACTIVE;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static java.util.Objects.requireNonNull;

public class TopologyAwareNodeSelectorFactory
implements NodeSelectorFactory
{
private static final Logger LOG = Logger.get(TopologyAwareNodeSelectorFactory.class);

private final Cache<InternalNode, Boolean> inaccessibleNodeLogCache = CacheBuilder.newBuilder()
.expireAfterWrite(30, TimeUnit.SECONDS)
.build();
private final NonEvictableCache<InternalNode, Object> inaccessibleNodeLogCache = buildNonEvictableCache(
CacheBuilder.newBuilder()
.expireAfterWrite(30, TimeUnit.SECONDS));

private final NetworkTopology networkTopology;
private final InternalNodeManager nodeManager;
Expand Down Expand Up @@ -164,13 +166,27 @@ private NodeMap createNodeMap(Optional<CatalogName> catalogName)
byHost.put(node.getInternalAddress(), node);
}
catch (UnknownHostException e) {
if (inaccessibleNodeLogCache.getIfPresent(node) == null) {
inaccessibleNodeLogCache.put(node, true);
if (markInaccessibleNode(node)) {
LOG.warn(e, "Unable to resolve host name for node: %s", node);
}
}
}

return new NodeMap(byHostAndPort.build(), byHost.build(), workersByNetworkPath.build(), coordinatorNodeIds);
}

/**
 * Records {@code node} in {@code inaccessibleNodeLogCache} unless the cache already holds an entry for it.
 *
 * @param node the node whose host name could not be resolved
 * @return {@code true} if this call newly marked the node as inaccessible,
 * {@code false} if the node was already marked
 */
private boolean markInaccessibleNode(InternalNode node)
{
    // A unique object per call: getting this exact instance back means our loader supplied the cached value.
    Object freshMarker = new Object();
    Object cachedMarker;
    try {
        cachedMarker = inaccessibleNodeLogCache.get(node, () -> freshMarker);
    }
    catch (ExecutionException e) {
        // impossible
        throw new RuntimeException(e);
    }
    return cachedMarker == freshMarker;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Suppliers;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableSetMultimap;
import io.airlift.log.Logger;
Expand All @@ -25,6 +24,7 @@
import io.trino.execution.NodeTaskMap;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.spi.HostAddress;
import io.trino.spi.SplitWeight;

Expand All @@ -34,13 +34,15 @@
import java.net.UnknownHostException;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask;
import static io.trino.metadata.NodeState.ACTIVE;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
Expand All @@ -50,9 +52,9 @@ public class UniformNodeSelectorFactory
{
private static final Logger LOG = Logger.get(UniformNodeSelectorFactory.class);

private final Cache<InternalNode, Boolean> inaccessibleNodeLogCache = CacheBuilder.newBuilder()
.expireAfterWrite(30, TimeUnit.SECONDS)
.build();
private final NonEvictableCache<InternalNode, Object> inaccessibleNodeLogCache = buildNonEvictableCache(
CacheBuilder.newBuilder()
.expireAfterWrite(30, TimeUnit.SECONDS));

private final InternalNodeManager nodeManager;
private final int minCandidates;
Expand Down Expand Up @@ -143,13 +145,27 @@ private NodeMap createNodeMap(Optional<CatalogName> catalogName)
byHost.put(node.getInternalAddress(), node);
}
catch (UnknownHostException e) {
if (inaccessibleNodeLogCache.getIfPresent(node) == null) {
inaccessibleNodeLogCache.put(node, true);
if (markInaccessibleNode(node)) {
LOG.warn(e, "Unable to resolve host name for node: %s", node);
}
}
}

return new NodeMap(byHostAndPort.build(), byHost.build(), ImmutableSetMultimap.of(), coordinatorNodeIds);
}

/**
 * Records {@code node} in {@code inaccessibleNodeLogCache} unless the cache already holds an entry for it.
 *
 * @param node the node whose host name could not be resolved
 * @return {@code true} if this call newly marked the node as inaccessible,
 * {@code false} if the node was already marked
 */
private boolean markInaccessibleNode(InternalNode node)
{
    // A unique object per call: getting this exact instance back means our loader supplied the cached value.
    Object freshMarker = new Object();
    Object cachedMarker;
    try {
        cachedMarker = inaccessibleNodeLogCache.get(node, () -> freshMarker);
    }
    catch (ExecutionException e) {
        // impossible
        throw new RuntimeException(e);
    }
    return cachedMarker == freshMarker;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
import com.fasterxml.jackson.databind.ser.BeanSerializerFactory;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
import com.fasterxml.jackson.databind.type.TypeFactory;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.plugin.base.cache.SafeCaches;

import java.io.IOException;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -92,7 +93,7 @@ private static class InternalTypeSerializer<T>
extends StdSerializer<T>
{
private final TypeSerializer typeSerializer;
private final Cache<Class<?>, JsonSerializer<T>> serializerCache = CacheBuilder.newBuilder().build();
private final NonEvictableCache<Class<?>, JsonSerializer<T>> serializerCache = SafeCaches.buildNonEvictableCache(CacheBuilder.newBuilder());

public InternalTypeSerializer(Class<T> baseClass, TypeIdResolver typeIdResolver)
{
Expand Down
Loading