Skip to content

Commit

Permalink
Refactor type check cache
Browse files Browse the repository at this point in the history
  • Loading branch information
heshanpadmasiri committed Feb 8, 2025
1 parent 9b9388a commit d4ae000
Show file tree
Hide file tree
Showing 18 changed files with 273 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
import io.ballerina.runtime.api.types.FiniteType;
import io.ballerina.runtime.api.types.JsonType;
import io.ballerina.runtime.api.types.MapType;
import io.ballerina.runtime.api.types.NamedTypeIdentifier;
import io.ballerina.runtime.api.types.ObjectType;
import io.ballerina.runtime.api.types.RecordType;
import io.ballerina.runtime.api.types.StreamType;
import io.ballerina.runtime.api.types.TableType;
import io.ballerina.runtime.api.types.TupleType;
import io.ballerina.runtime.api.types.Type;
import io.ballerina.runtime.api.types.TypeIdentifier;
import io.ballerina.runtime.api.types.UnionType;
import io.ballerina.runtime.api.types.XmlType;
import io.ballerina.runtime.internal.TypeCheckLogger;
Expand Down Expand Up @@ -542,7 +544,7 @@ private static BRecordType registeredRecordType(String typeName, Module pkg) {
if (typeName == null || pkg == null) {
return null;
}
return registeredRecordTypes.get(new TypeIdentifier(typeName, pkg));
return registeredRecordTypes.get(new NamedTypeIdentifier(pkg, typeName));
}

public static void registerRecordType(BRecordType recordType) {
Expand All @@ -554,7 +556,7 @@ public static void registerRecordType(BRecordType recordType) {
if (name.contains("$anon")) {
return;
}
TypeIdentifier typeIdentifier = new TypeIdentifier(name, pkg);
TypeIdentifier typeIdentifier = new NamedTypeIdentifier(pkg, name);
registeredRecordTypes.put(typeIdentifier, recordType);
}

Expand All @@ -571,14 +573,6 @@ void put(TypeIdentifier identifier, BRecordType value) {
}
}

public record TypeIdentifier(String typeName, Module pkg) {

public TypeIdentifier {
assert typeName != null;
assert pkg != null;
}
}

private static final class MapTypeCache {

private static final Map<Type, MapType> cache = new ConcurrentHashMap<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.ballerina.runtime.api.types;

import io.ballerina.runtime.api.Module;

public record NamedTypeIdentifier(Module pkg, String name) implements TypeIdentifier {

public NamedTypeIdentifier {
if (pkg == null) {
throw new IllegalArgumentException("Package cannot be null");
}
if (name == null) {
throw new IllegalArgumentException("Name cannot be null");
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.ballerina.runtime.api.types;

public interface TypeIdentifier {

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ public interface CacheableTypeDescriptor extends Type {
* @param result Result of the type check
*/
void cacheTypeCheckResult(CacheableTypeDescriptor other, boolean result);

TypeId typeId();
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import io.ballerina.runtime.internal.types.semtype.MappingAtomicType;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -46,24 +45,9 @@ public final class Context {
public final Map<Bdd, BddMemo> listMemo = new WeakHashMap<>();
public final Map<Bdd, BddMemo> mappingMemo = new WeakHashMap<>();
public final Map<Bdd, BddMemo> functionMemo = new WeakHashMap<>();
private static final int MAX_CACHE_SIZE = 100;
private final Map<CacheableTypeDescriptor, TypeCheckCache<CacheableTypeDescriptor>> typeCheckCacheMemo;

private Context(Env env) {
this.env = env;
this.typeCheckCacheMemo = createTypeCheckCacheMemo();
}

private static Map<CacheableTypeDescriptor, TypeCheckCache<CacheableTypeDescriptor>> createTypeCheckCacheMemo() {
// This is fine since this map is not going to get leaked out of the context and
// context is unique to a thread. So there will be no concurrent modifications
return new LinkedHashMap<>(MAX_CACHE_SIZE, 1f, true) {
@Override
protected boolean removeEldestEntry(
Map.Entry<CacheableTypeDescriptor, TypeCheckCache<CacheableTypeDescriptor>> eldest) {
return size() > MAX_CACHE_SIZE;
}
};
}

public static Context from(Env env) {
Expand Down Expand Up @@ -140,12 +124,4 @@ public MappingAtomicType mappingAtomType(Atom atom) {
public FunctionAtomicType functionAtomicType(Atom atom) {
return env.functionAtomType(atom);
}

public TypeCheckCache<CacheableTypeDescriptor> getTypeCheckCache(CacheableTypeDescriptor typeDescriptor) {
return typeCheckCacheMemo.computeIfAbsent(typeDescriptor, TypeCheckCache::new);
}

enum Phase {
INIT, TYPE_RESOLUTION, TYPE_CHECKING
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package io.ballerina.runtime.api.types.semtype;

import io.ballerina.runtime.api.types.Type;

import java.util.Map;
import java.util.Optional;
import java.util.WeakHashMap;
import java.util.concurrent.ConcurrentHashMap;

/**
* Generalized implementation of type check result cache. It is okay to access
Expand All @@ -16,26 +14,27 @@
* @param <T> Type of the type descriptor which owns this cache
* @since 2201.11.0
*/
public class TypeCheckCache<T extends Type> {
public class TypeCheckCache {

// Not synchronizing this should be fine since race conditions don't lead to inconsistent results. (i.e. results
// of doing multiple type checks are agnostic to the order of execution). Data races shouldn't lead to tearing in
// 64-bit JVMs.
private final Map<T, Boolean> cachedResults = new WeakHashMap<>();
private final T owner;
private final Map<TypeId, Boolean> cachedResults = new ConcurrentHashMap<>();
private final TypeId ownerId;

public TypeCheckCache(T owner) {
this.owner = owner;
public TypeCheckCache(CacheableTypeDescriptor owner) {
this.ownerId = owner.typeId();
}

public Optional<Boolean> cachedTypeCheckResult(T other) {
if (other.equals(owner)) {
public Optional<Boolean> cachedTypeCheckResult(CacheableTypeDescriptor other) {
if (other.typeId().equals(ownerId)) {
return Optional.of(true);
}
return Optional.ofNullable(cachedResults.get(other));
return Optional.ofNullable(cachedResults.get(other.typeId()));
}

public void cacheTypeCheckResult(CacheableTypeDescriptor other, boolean result) {
cachedResults.put(other.typeId(), result);
}

public void cacheTypeCheckResult(T other, boolean result) {
cachedResults.put(other, result);
public void reset() {
cachedResults.clear();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package io.ballerina.runtime.api.types.semtype;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public final class TypeCheckCacheFactory {

private static final Map<TypeId, TypeCheckCache> cacheMap = new ConcurrentHashMap<>();

private TypeCheckCacheFactory() {
}

public static TypeCheckCache get(CacheableTypeDescriptor owner) {
return cacheMap.computeIfAbsent(owner.typeId(), k -> new TypeCheckCache(owner));
}

public static void resetCache() {
cacheMap.forEach((k, v) -> v.reset());
cacheMap.clear();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.ballerina.runtime.api.types.semtype;

public interface TypeId {

}
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ private static boolean isSubTypeWithCache(Context cx, CacheableTypeDescriptor so
Optional<Boolean> cachedResult = source.cachedTypeCheckResult(cx, target);
logger.typeCheckCachedResult(cx, source, target, cachedResult);
if (cachedResult.isPresent()) {
assert cachedResult.get() == isSubTypeInner(cx, source, target);
// assert cachedResult.get() == isSubTypeInner(cx, source, target);
return cachedResult.get();
}
boolean result = isSubTypeInner(cx, source, target);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@
import io.ballerina.runtime.api.Module;
import io.ballerina.runtime.api.creators.ErrorCreator;
import io.ballerina.runtime.api.types.IntersectionType;
import io.ballerina.runtime.api.types.NamedTypeIdentifier;
import io.ballerina.runtime.api.types.Type;
import io.ballerina.runtime.api.types.TypeTags;
import io.ballerina.runtime.api.types.semtype.CacheableTypeDescriptor;
import io.ballerina.runtime.api.types.semtype.Context;
import io.ballerina.runtime.api.types.semtype.SemType;
import io.ballerina.runtime.api.types.semtype.TypeCheckCache;
import io.ballerina.runtime.api.types.semtype.TypeCheckCacheFactory;
import io.ballerina.runtime.api.types.semtype.TypeId;
import io.ballerina.runtime.api.utils.StringUtils;
import io.ballerina.runtime.internal.TypeCheckLogger;
import io.ballerina.runtime.internal.TypeChecker;
import io.ballerina.runtime.internal.types.semtype.MutableSemType;
import io.ballerina.runtime.internal.types.semtype.TypeIdFactory;

import java.util.HashSet;
import java.util.Objects;
Expand Down Expand Up @@ -58,8 +62,10 @@ public abstract non-sealed class BType extends SemType
private Type cachedReferredType = null;
private Type cachedImpliedType = null;
private volatile SemType cachedSemType = null;
private volatile TypeCheckCache<CacheableTypeDescriptor> typeCheckCache;
private volatile TypeCheckCache typeCheckCache;
private final ReadWriteLock typeCacheLock = new ReentrantReadWriteLock();
private final ReadWriteLock typeIdLock = new ReentrantReadWriteLock();
private volatile TypeId typeId;

protected BType(String typeName, Module pkg, Class<? extends Object> valueClass) {
this.typeName = typeName;
Expand Down Expand Up @@ -290,7 +296,11 @@ public BType clone() {

@Override
public boolean shouldCache() {
return this.pkg != null && this.typeName != null && !this.typeName.contains("$anon");
return !isAnonType();
}

private boolean isAnonType() {
return this.pkg == null || this.typeName == null || this.typeName.isEmpty() || this.typeName.contains("$anon");
}

@Override
Expand All @@ -314,7 +324,7 @@ private void initializeCacheIfNeeded(Context cx) {
try {
typeCacheLock.writeLock().lock();
if (typeCheckCache == null) {
typeCheckCache = cx.getTypeCheckCache(this);
typeCheckCache = TypeCheckCacheFactory.get(this);
}
} finally {
typeCacheLock.writeLock().unlock();
Expand Down Expand Up @@ -343,4 +353,25 @@ public final boolean isDependentlyTyped(Set<MayBeDependentType> visited) {
protected boolean isDependentlyTypedInner(Set<MayBeDependentType> visited) {
return false;
}

@Override
public TypeId typeId() {
typeIdLock.readLock().lock();
try {
if (typeId != null) {
return typeId;
}
} finally {
typeIdLock.readLock().unlock();
}
typeIdLock.writeLock().lock();
try {
if (typeId == null) {
typeId = TypeIdFactory.getTypeId(new NamedTypeIdentifier(pkg, typeName));
}
return typeId;
} finally {
typeIdLock.writeLock().unlock();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package io.ballerina.runtime.internal.types.semtype;

import io.ballerina.runtime.api.types.TypeIdentifier;
import io.ballerina.runtime.api.types.semtype.TypeId;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

public class TypeIdFactory {

private static final Map<TypeIdentifier, TypeId> typeIds = new HashMap<>();
private static final ReadWriteLock lock = new ReentrantReadWriteLock();
private static long nextId = 0;

public static TypeId getTypeId(TypeIdentifier identifier) {
lock.readLock().lock();
try {
TypeId cached = typeIds.get(identifier);
if (cached != null) {
return cached;
}
} finally {
lock.readLock().unlock();
}
lock.writeLock().lock();
try {
TypeId cached = typeIds.get(identifier);
if (cached != null) {
return cached;
}
TypeId typeId = new IntegerBasedTypeId(nextId++);
typeIds.put(identifier, typeId);
return typeId;
} finally {
lock.writeLock().unlock();
}
}

private record IntegerBasedTypeId(long id) implements TypeId {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.ballerina.projects.repos.FileSystemCache;
import io.ballerina.projects.util.ProjectConstants;
import io.ballerina.projects.util.ProjectUtils;
import io.ballerina.runtime.api.types.semtype.TypeCheckCacheFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wso2.ballerinalang.compiler.bir.model.BIRNode;
Expand Down Expand Up @@ -89,6 +90,7 @@ public static Project loadProject(String sourceFilePath, BuildOptions buildOptio
}

public static CompileResult compile(String sourceFilePath) {
TypeCheckCacheFactory.resetCache();
Project project = loadProject(sourceFilePath);

Package currentPackage = project.currentPackage();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ public void testConstrainedMapRefTypeCastNegative() {
String errorMsg =
((BMap<String, BString>) ((BError) returns).getDetails()).get(StringUtils.fromString("message"))
.toString();
Assert.assertTrue(errorMsg.startsWith("incompatible types: 'map<Person>' cannot be cast to 'map<int>'"));
Assert.assertTrue(errorMsg.startsWith("incompatible types: 'map<PersonCM>' cannot be cast to 'map<int>'"));

}

Expand Down Expand Up @@ -353,7 +353,8 @@ public void testStructNotEquivalentRuntimeCast() {
String errorMsg =
((BMap<String, BString>) ((BError) returns).getDetails()).get(StringUtils.fromString("message"))
.toString();
Assert.assertTrue(errorMsg.startsWith("incompatible types: 'map<Employee>' cannot be cast to 'map<Person>'"));
Assert.assertTrue(
errorMsg.startsWith("incompatible types: 'map<EmployeeCM>' cannot be cast to 'map<PersonCM>'"));
}

@Test(description = "Test runtime cast for any map to int map.")
Expand All @@ -373,7 +374,7 @@ public void testAnyMapToRefTypeRuntimeCast() {
String errorMsg =
((BMap<String, BString>) ((BError) returns).getDetails()).get(StringUtils.fromString("message"))
.toString();
Assert.assertTrue(errorMsg.startsWith("incompatible types: 'map' cannot be cast to 'map<Employee>'"));
Assert.assertTrue(errorMsg.startsWith("incompatible types: 'map' cannot be cast to 'map<EmployeeCM>'"));
}

@Test(description = "Test struct to map conversion for constrained map.")
Expand Down
Loading

0 comments on commit d4ae000

Please sign in to comment.