Skip to content

Commit

Permalink
refactor(controller): refactor storage env related code (#1249)
Browse files Browse the repository at this point in the history
* add minio sdk, refactor storage env related code

* update default storage to minio
  • Loading branch information
anda-ren authored Sep 20, 2022
1 parent a7e5f33 commit b7a536c
Show file tree
Hide file tree
Showing 32 changed files with 859 additions and 220 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package ai.starwhale.mlops.api.protocol.swds;

import ai.starwhale.mlops.api.protocol.StorageFileVo;
import ai.starwhale.mlops.storage.fs.FileStorageEnv;
import ai.starwhale.mlops.storage.env.StorageEnv;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.io.Serializable;
Expand Down Expand Up @@ -51,11 +51,6 @@ public class SwDatasetInfoVo implements Serializable {
*/
String indexTable;

/**
* the necessary information to access to file storages key: storage name value: envs
*/
Map<String, FileStorageEnv> fileStorageEnvs;

@JsonProperty("versionTag")
private String versionTag;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
import ai.starwhale.mlops.exception.SwValidationException.ValidSubject;
import ai.starwhale.mlops.exception.api.StarwhaleApiException;
import ai.starwhale.mlops.storage.configuration.StorageProperties;
import ai.starwhale.mlops.storage.fs.FileStorageEnv;
import ai.starwhale.mlops.storage.env.StorageEnv;
import cn.hutool.core.util.StrUtil;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
Expand Down Expand Up @@ -107,9 +107,6 @@ public class SwDatasetService {
@Resource
private UserService userService;

@Resource
private StorageProperties storageProperties;

@Resource
private DsFileGetter dsFileGetter;

Expand Down Expand Up @@ -180,9 +177,6 @@ private SwDatasetInfoVo toSwDatasetInfoVo(SwDatasetEntity ds, SwDatasetVersionEn
try {
String storagePath = versionEntity.getStoragePath();
List<StorageFileVo> collect = storageService.listStorageFile(storagePath);
Map<String, FileStorageEnv> fileStorageEnvs = storageProperties.toFileStorageEnvs();
fileStorageEnvs.values().forEach(fileStorageEnv -> fileStorageEnv.add(FileStorageEnv.ENV_KEY_PREFIX,
versionEntity.getStoragePath()));
return SwDatasetInfoVo.builder()
.id(idConvertor.convert(ds.getId()))
.name(ds.getDatasetName())
Expand All @@ -191,7 +185,6 @@ private SwDatasetInfoVo toSwDatasetInfoVo(SwDatasetEntity ds, SwDatasetVersionEn
.versionTag(versionEntity.getVersionTag())
.versionMeta(versionEntity.getVersionMeta())
.createdTime(localDateTimeConvertor.convert(versionEntity.getCreatedTime()))
.fileStorageEnvs(fileStorageEnvs)
.indexTable(versionEntity.getIndexTable())
.files(collect)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package ai.starwhale.mlops.domain.swds.bo;

import ai.starwhale.mlops.storage.fs.FileStorageEnv;
import ai.starwhale.mlops.storage.env.StorageEnv;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Builder;
Expand Down Expand Up @@ -65,5 +65,5 @@ public class SwDataSet {
/**
* the necessary information to access to file storages key: storage name value: envs
*/
Map<String, FileStorageEnv> fileStorageEnvs;
Map<String, StorageEnv> fileStorageEnvs;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,32 @@
package ai.starwhale.mlops.domain.swds.converter;

import ai.starwhale.mlops.domain.swds.bo.SwDataSet;
import ai.starwhale.mlops.domain.swds.objectstore.StorageAuths;
import ai.starwhale.mlops.domain.swds.po.SwDatasetVersionEntity;
import ai.starwhale.mlops.storage.configuration.StorageProperties;
import ai.starwhale.mlops.storage.fs.FileStorageEnv;
import ai.starwhale.mlops.storage.env.StorageEnv;
import ai.starwhale.mlops.storage.env.StorageEnvsPropertiesConverter;
import ai.starwhale.mlops.storage.env.UserStorageAuthEnv;
import java.util.Map;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

@Component
public class SwdsBoConverter {

final StorageProperties storageProperties;
final StorageEnvsPropertiesConverter storageEnvsPropertiesConverter;

public SwdsBoConverter(StorageProperties storageProperties) {
this.storageProperties = storageProperties;
public SwdsBoConverter(StorageEnvsPropertiesConverter storageEnvsPropertiesConverter) {
this.storageEnvsPropertiesConverter = storageEnvsPropertiesConverter;
}

public SwDataSet fromEntity(SwDatasetVersionEntity swDatasetVersionEntity) {
Map<String, FileStorageEnv> fileStorageEnvs;
Map<String, StorageEnv> fileStorageEnvs;
if (StringUtils.hasText(swDatasetVersionEntity.getStorageAuths())) {
StorageAuths storageAuths = new StorageAuths(swDatasetVersionEntity.getStorageAuths());
UserStorageAuthEnv storageAuths = new UserStorageAuthEnv(swDatasetVersionEntity.getStorageAuths());
fileStorageEnvs = storageAuths.allEnvs();
} else {
fileStorageEnvs = storageProperties.toFileStorageEnvs();
fileStorageEnvs = storageEnvsPropertiesConverter.propertiesToEnvs();
}
fileStorageEnvs.values().forEach(fileStorageEnv -> fileStorageEnv.add(FileStorageEnv.ENV_KEY_PREFIX,
fileStorageEnvs.values().forEach(fileStorageEnv -> fileStorageEnv.add(StorageEnv.ENV_KEY_PREFIX,
swDatasetVersionEntity.getStoragePath()));
return SwDataSet.builder()
.id(swDatasetVersionEntity.getId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ai.starwhale.mlops.exception.SwProcessException.ErrorType;
import ai.starwhale.mlops.storage.StorageAccessService;
import ai.starwhale.mlops.storage.StorageObjectInfo;
import ai.starwhale.mlops.storage.StorageUri;
import java.io.IOException;
import java.io.InputStream;
import lombok.extern.slf4j.Slf4j;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
import ai.starwhale.mlops.exception.SwValidationException;
import ai.starwhale.mlops.exception.SwValidationException.ValidSubject;
import ai.starwhale.mlops.storage.StorageAccessService;
import ai.starwhale.mlops.storage.aliyun.StorageAccessServiceAliyun;
import ai.starwhale.mlops.storage.fs.FileStorageEnv;
import ai.starwhale.mlops.storage.s3.S3Config;
import ai.starwhale.mlops.storage.s3.StorageAccessServiceS3;
import java.util.Map;
import ai.starwhale.mlops.storage.StorageUri;
import ai.starwhale.mlops.storage.env.StorageEnv;
import ai.starwhale.mlops.storage.env.UserStorageAccessServiceBuilder;
import ai.starwhale.mlops.storage.env.UserStorageAuthEnv;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
Expand All @@ -37,16 +36,20 @@ public class StorageAccessParser {

final SwDatasetVersionMapper swDatasetVersionMapper;

final UserStorageAccessServiceBuilder userStorageAccessServiceBuilder;

ConcurrentHashMap<String, StorageAccessService> storageAccessServicePool = new ConcurrentHashMap<>();

public StorageAccessParser(StorageAccessService defaultStorageAccessService,
SwDatasetVersionMapper swDatasetVersionMapper) {
SwDatasetVersionMapper swDatasetVersionMapper,
UserStorageAccessServiceBuilder userStorageAccessServiceBuilder) {
this.defaultStorageAccessService = defaultStorageAccessService;
this.swDatasetVersionMapper = swDatasetVersionMapper;
this.userStorageAccessServiceBuilder = userStorageAccessServiceBuilder;
}

public StorageAccessService getStorageAccessServiceFromAuth(Long datasetId, String uri,
String authName) {
String authName) {
if (StringUtils.hasText(authName)) {
authName = authName.toUpperCase(); // env vars are uppercase always
}
Expand All @@ -62,73 +65,24 @@ public StorageAccessService getStorageAccessServiceFromAuth(Long datasetId, Stri
return defaultStorageAccessService;
}

StorageAuths storageAuths = new StorageAuths(storageAuthsText);
FileStorageEnv env = storageAuths.getEnv(authName);
UserStorageAuthEnv storageAuths = new UserStorageAuthEnv(storageAuthsText);
StorageEnv env = storageAuths.getEnv(authName);
if (null == env) {
return defaultStorageAccessService;
}

switch (env.getEnvType()) {
case S3:
var s3 = new StorageAccessServiceS3(env2S3Config(new StorageUri(uri), env, authName));
storageAccessServicePool.putIfAbsent(formatKey(datasetId, authName), s3);
return s3;
case ALIYUN:
var aliyun = new StorageAccessServiceAliyun(env2S3Config(new StorageUri(uri), env, authName));
storageAccessServicePool.putIfAbsent(formatKey(datasetId, authName), aliyun);
return aliyun;
default:
throw new SwValidationException(ValidSubject.SWDS).tip(
"file system not supported yet: " + env.getEnvType());
StorageAccessService storageAccessService = userStorageAccessServiceBuilder.build(env, new StorageUri(uri),
authName);
if (null == storageAccessService) {
throw new SwValidationException(ValidSubject.SWDS).tip(
"file system not supported yet: " + env.getEnvType());
}
storageAccessServicePool.putIfAbsent(formatKey(datasetId, authName), storageAccessService);
return storageAccessService;
}

String formatKey(Long datasetId, String authName) {
return datasetId.toString() + authName;
}


static final String KEY_BUCKET = "USER.S3.%sBUCKET";
static final String KEY_REGION = "USER.S3.%sREGION";
static final String KEY_ENDPOINT = "USER.S3.%sENDPOINT";
static final String KEY_SECRET = "USER.S3.%sSECRET";
static final String KEY_ACCESS_KEY = "USER.S3.%sACCESS_KEY";

S3Config env2S3Config(StorageUri storageUri, FileStorageEnv env, String authName) {
if (StringUtils.hasText(authName)) {
authName = authName + ".";
} else {
authName = "";
}
authName = authName.toUpperCase();
Map<String, String> envs = env.getEnvs();
String bucket = StringUtils.hasText(storageUri.getBucket()) ? storageUri.getBucket()
: envs.get(String.format(KEY_BUCKET, authName));
String accessKey = StringUtils.hasText(storageUri.getUsername()) ? storageUri.getUsername()
: envs.get(String.format(KEY_ACCESS_KEY, authName));
String accessSecret =
StringUtils.hasText(storageUri.getPassword()) ? storageUri.getPassword()
: envs.get(String.format(KEY_SECRET, authName));
String endpoint = StringUtils.hasText(storageUri.getHost()) ? buildEndPoint(storageUri)
: envs.get(String.format(KEY_ENDPOINT, authName));
return S3Config.builder()
.bucket(bucket)
.accessKey(accessKey)
.secretKey(accessSecret)
.region(envs.get(String.format(KEY_REGION, authName)))
.endpoint(endpoint)
.build();
}

private String buildEndPoint(StorageUri storageUri) {
if (null == storageUri.getPort() || 80 == storageUri.getPort()) {
return "http://" + storageUri.getHost();
} else if (443 == storageUri.getPort()) {
return "https://" + storageUri.getHost();
} else {
return "http://" + storageUri.getHost() + ":" + storageUri.getPort();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@
import ai.starwhale.mlops.domain.task.status.watchers.TaskWatcherForSchedule;
import ai.starwhale.mlops.schedule.SwTaskScheduler;
import ai.starwhale.mlops.storage.configuration.StorageProperties;
import ai.starwhale.mlops.storage.fs.AliyunEnv;
import ai.starwhale.mlops.storage.fs.BotoS3Config;
import ai.starwhale.mlops.storage.fs.FileStorageEnv;
import ai.starwhale.mlops.storage.env.StorageEnv;
import ai.starwhale.mlops.storage.env.StorageEnvsPropertiesConverter;
import cn.hutool.json.JSONUtil;
import io.kubernetes.client.informer.ResourceEventHandler;
import io.kubernetes.client.openapi.ApiException;
Expand Down Expand Up @@ -75,6 +74,7 @@ public class K8sTaskScheduler implements SwTaskScheduler {
final ResourceEventHandler<V1Job> eventHandlerJob;
final ResourceEventHandler<V1Node> eventHandlerNode;
final String instanceUri;
final StorageEnvsPropertiesConverter storageEnvsPropertiesConverter;

public K8sTaskScheduler(K8sClient k8sClient,
StorageProperties storageProperties,
Expand All @@ -83,7 +83,8 @@ public K8sTaskScheduler(K8sClient k8sClient,
K8sResourcePoolConverter resourcePoolConverter,
K8sJobTemplate k8sJobTemplate,
ResourceEventHandler<V1Job> eventHandlerJob,
ResourceEventHandler<V1Node> eventHandlerNode, @Value("${sw.instance-uri}") String instanceUri) {
ResourceEventHandler<V1Node> eventHandlerNode, @Value("${sw.instance-uri}") String instanceUri,
StorageEnvsPropertiesConverter storageEnvsPropertiesConverter) {
this.k8sClient = k8sClient;
this.storageProperties = storageProperties;
this.jobTokenConfig = jobTokenConfig;
Expand All @@ -93,6 +94,7 @@ public K8sTaskScheduler(K8sClient k8sClient,
this.eventHandlerJob = eventHandlerJob;
this.eventHandlerNode = eventHandlerNode;
this.instanceUri = instanceUri;
this.storageEnvsPropertiesConverter = storageEnvsPropertiesConverter;
}

@Override
Expand Down Expand Up @@ -190,7 +192,7 @@ private Map<String, String> buildCoreContainerEnvs(Task task) {
swDataSets.forEach(ds -> ds.getFileStorageEnvs().values()
.forEach(fileStorageEnv -> coreContainerEnvs.putAll(fileStorageEnv.getEnvs())));

coreContainerEnvs.put(FileStorageEnv.ENV_KEY_PREFIX, swDataSet.getPath());
coreContainerEnvs.put(StorageEnv.ENV_KEY_PREFIX, swDataSet.getPath());

// datastore env
coreContainerEnvs.put("SW_TOKEN", jobTokenConfig.getToken());
Expand All @@ -204,9 +206,9 @@ private Map<String, String> getInitContainerEnvs(Task task) {
Job swJob = task.getStep().getJob();
JobRuntime jobRuntime = swJob.getJobRuntime();
Map<String, String> initContainerEnvs = new HashMap<>();
Map<String, FileStorageEnv> fileStorageEnvs = storageProperties.toFileStorageEnvs();
Map<String, StorageEnv> fileStorageEnvs = storageEnvsPropertiesConverter.propertiesToEnvs();
// Ignore keys, we use only one key by now
fileStorageEnvs.values().forEach((FileStorageEnv env) -> initContainerEnvs.putAll(env.getEnvs()));
fileStorageEnvs.values().forEach((StorageEnv env) -> initContainerEnvs.putAll(env.getEnvs()));

List<String> downloads = new ArrayList<>();
String prefix = "s3://" + storageProperties.getS3Config().getBucket() + "/";
Expand Down
2 changes: 1 addition & 1 deletion server/controller/src/main/resources/application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ sw:
host-path-for-cache: ${SW_K8S_HOST_PATH_FOR_CACHE:/mnt/data}
job-template-path: ${SW_K8S_JOB_TEMPLATE_PATH:}
storage:
type: ${SW_STORAGE_TYPE:s3}
type: ${SW_STORAGE_TYPE:minio}
path-prefix: ${SW_STORAGE_PREFIX:starwhale}
fs-root-dir: ${SW_STORAGE_FS_ROOT_DIR:/usr/local/starwhale}
s3-config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@

package ai.starwhale.mlops.domain.swds.objectstore;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import ai.starwhale.mlops.domain.swds.mapper.SwDatasetVersionMapper;
import ai.starwhale.mlops.domain.swds.po.SwDatasetVersionEntity;
import ai.starwhale.mlops.storage.StorageAccessService;
import ai.starwhale.mlops.storage.fs.FileStorageEnv;
import ai.starwhale.mlops.storage.fs.FileStorageEnv.FileSystemEnvType;
import ai.starwhale.mlops.storage.s3.S3Config;
import ai.starwhale.mlops.storage.env.UserStorageAccessServiceBuilder;
import ai.starwhale.mlops.storage.minio.StorageAccessServiceMinio;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

Expand All @@ -39,38 +40,39 @@ public class StorageAccessParserTest {
+ "USER.S3.MNIST.SECRET=\n"
+ "USER.S3.MYNAME.ACCESS_KEY=access_key1\n";

StorageAccessService defaultStorageAccessService = mock(StorageAccessService.class);

SwDatasetVersionMapper swDatasetVersionMapper = mock(SwDatasetVersionMapper.class);


@Test
public void testDefaultService() {
StorageAccessService defaultStorageAccessService = mock(StorageAccessService.class);
SwDatasetVersionMapper swDatasetVersionMapper = mock(SwDatasetVersionMapper.class);
when(swDatasetVersionMapper.getVersionById(1L)).thenReturn(
SwDatasetVersionEntity.builder().id(1L).storageAuths(auths).build());
when(swDatasetVersionMapper.getVersionById(2L)).thenReturn(
SwDatasetVersionEntity.builder().id(2L).storageAuths("").build());
StorageAccessParser storageAccessParser = new StorageAccessParser(defaultStorageAccessService,
swDatasetVersionMapper);
swDatasetVersionMapper, null);
StorageAccessService storageAccessService = storageAccessParser.getStorageAccessServiceFromAuth(2L,
"s3://renyanda/bdc/xyf",
"myname");
Assertions.assertEquals(defaultStorageAccessService, storageAccessService);
}

@Test
public void testEnv2S3Config() {
StorageAccessParser storageAccessParser = new StorageAccessParser(null, null);
S3Config s3Config = storageAccessParser.env2S3Config(
new StorageUri("s3://renyanda/bdc/xyf"), new FileStorageEnv(
FileSystemEnvType.S3)
.add("USER.S3.MYTEST.ENDPOINT", "endpoint")
.add("USER.S3.URTEST.ENDPOINT", "EDP")
.add("USER.S4.mytest.endpoint", "dpd")
.add("USER.S3.mytest.SECRET", "SCret")
.add("USER.S3.mytest.ACCESS_KEY", "ack")
.add("USER.S3.mytest.BUCKET", "bkt")
.add("USER.S3.MYTEST.REGION", "region"), "mytest");
Assertions.assertEquals("renyanda", s3Config.getBucket());
Assertions.assertEquals("ack", s3Config.getAccessKey());
Assertions.assertEquals("SCret", s3Config.getSecretKey());
Assertions.assertEquals("region", s3Config.getRegion());
public void testCache() {
when(swDatasetVersionMapper.getVersionById(1L)).thenReturn(
SwDatasetVersionEntity.builder().id(1L).storageAuths(auths).build());
UserStorageAccessServiceBuilder userStorageAccessServiceBuilder = mock(UserStorageAccessServiceBuilder.class);
StorageAccessServiceMinio storageAccessServiceMinio = mock(StorageAccessServiceMinio.class);
when(userStorageAccessServiceBuilder.build(any(), any(), any())).thenReturn(storageAccessServiceMinio);
StorageAccessParser storageAccessParser = new StorageAccessParser(defaultStorageAccessService,
swDatasetVersionMapper, userStorageAccessServiceBuilder);

StorageAccessService myname = storageAccessParser.getStorageAccessServiceFromAuth(1L, "s3://renyanda/bdc/xyf",
"myname");
Assertions.assertEquals(storageAccessServiceMinio, myname);
myname = storageAccessParser.getStorageAccessServiceFromAuth(1L, "s3://renyanda/bdc/xyfzzz",
"myname");
Assertions.assertEquals(storageAccessServiceMinio, myname);
verify(userStorageAccessServiceBuilder).build(any(), any(), any());
}
}
Loading

0 comments on commit b7a536c

Please sign in to comment.