Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(controller): Support link auth in controller #999

Merged
merged 4 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
@Builder
@AllArgsConstructor
@NoArgsConstructor
@Deprecated
public class TaskTrigger {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,23 @@ void pullDS(
@RequestParam(name = "part_name",required = false) String partName,
HttpServletResponse httpResponse);

@Operation(summary = "Pull SWDS uri file contents",
description = "Pull SWDS uri file contents ")
@ApiResponses(value = {@ApiResponse(responseCode = "200", description = "ok")})
@GetMapping(
value = "/project/{projectUrl}/dataset/{datasetUrl}/version/{versionUrl}/link",
produces = MediaType.APPLICATION_OCTET_STREAM_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER')")
void pullLinkContent(
@PathVariable(name = "projectUrl") String projectUrl,
@PathVariable(name = "datasetUrl") String datasetUrl,
@PathVariable(name = "versionUrl") String versionUrl,
@Parameter(name = "uri", description = "uri of the link")
@RequestParam(name = "uri",required = true) String uri,
@Parameter(name = "authName", description = "auth name the link used")
@RequestParam(name = "authName",required = false) String authName,
HttpServletResponse httpResponse);


@Operation(summary = "Set the tag of the dataset version")
@ApiResponses(value = {@ApiResponse(responseCode = "200", description = "ok")})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import ai.starwhale.mlops.domain.swds.bo.SWDSVersion;
import ai.starwhale.mlops.domain.swds.bo.SWDSVersionQuery;
import ai.starwhale.mlops.domain.swds.SWDatasetService;
import ai.starwhale.mlops.domain.swds.po.SWDatasetVersionEntity;
import ai.starwhale.mlops.domain.swds.upload.SwdsUploader;
import ai.starwhale.mlops.exception.ApiOperationException;
import ai.starwhale.mlops.exception.SWProcessException;
Expand Down Expand Up @@ -186,6 +187,25 @@ public void pullDS(String projectUrl, String datasetUrl, String versionUrl,
swdsUploader.pull(projectUrl, datasetUrl, versionUrl, partName,httpResponse);
}

@Override
public void pullLinkContent(String projectUrl, String datasetUrl, String versionUrl,
String uri,String authName, HttpServletResponse httpResponse) {
if(!StringUtils.hasText(datasetUrl) || !StringUtils.hasText(versionUrl) ){
throw new StarWhaleApiException(new SWValidationException(ValidSubject.SWDS)
.tip("please provide name and version for the DS "), HttpStatus.BAD_REQUEST);
}
SWDatasetVersionEntity datasetVersionEntity = swDatasetService.query(projectUrl, datasetUrl, versionUrl);
try {
ServletOutputStream outputStream = httpResponse.getOutputStream();
outputStream.write(swDatasetService.dataOf(datasetVersionEntity.getId(),uri,authName));
outputStream.flush();
} catch (IOException e) {
log.error("error write data to response",e);
throw new SWProcessException(ErrorType.NETWORK).tip("error write data to response");
}

}

@Override
public ResponseEntity<ResponseMessage<String>> modifyDatasetVersionInfo(
String projectUrl, String datasetUrl, String versionUrl, SWDSTagRequest swdsTagRequest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import ai.starwhale.mlops.domain.swds.converter.SWDSVersionConvertor;
import ai.starwhale.mlops.domain.swds.mapper.SWDatasetMapper;
import ai.starwhale.mlops.domain.swds.mapper.SWDatasetVersionMapper;
import ai.starwhale.mlops.domain.swds.objectstore.DSFileGetter;
import ai.starwhale.mlops.domain.swds.po.SWDatasetEntity;
import ai.starwhale.mlops.domain.swds.po.SWDatasetVersionEntity;
import ai.starwhale.mlops.domain.user.UserService;
Expand Down Expand Up @@ -105,6 +106,8 @@ public class SWDatasetService {
@Resource
private StorageProperties storageProperties;

private DSFileGetter dsFileGetter;

private BundleManager bundleManager() {
return new BundleManager(idConvertor, projectManager, swdsManager, swdsManager, ValidSubject.SWDS);
}
Expand Down Expand Up @@ -274,7 +277,7 @@ private List<SWDatasetInfoVO> swDatasetInfoOfDs(SWDatasetEntity ds) {
.map(entity -> toSWDatasetInfoVO(ds, entity)).collect(Collectors.toList());
}

public String query(String projectUrl, String datasetUrl, String versionUrl) {
public SWDatasetVersionEntity query(String projectUrl, String datasetUrl, String versionUrl) {
Long projectId = projectManager.getProjectId(projectUrl);
SWDatasetEntity entity = swdsMapper.findByName(datasetUrl, projectId);
if(null == entity) {
Expand All @@ -284,6 +287,10 @@ public String query(String projectUrl, String datasetUrl, String versionUrl) {
if(null == versionEntity) {
throw new StarWhaleApiException(new SWValidationException(ValidSubject.SWDS), HttpStatus.NOT_FOUND);
}
return versionEntity.getName();
return versionEntity;
}

public byte[] dataOf(Long datasetId,String uri,String authName){
return dsFileGetter.dataOf(datasetId,uri,authName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
package ai.starwhale.mlops.domain.swds.converter;

import ai.starwhale.mlops.domain.swds.bo.SWDataSet;
import ai.starwhale.mlops.domain.swds.objectstore.StorageAuths;
import ai.starwhale.mlops.domain.swds.po.SWDatasetVersionEntity;
import ai.starwhale.mlops.storage.configuration.StorageProperties;
import ai.starwhale.mlops.storage.fs.FileStorageEnv;
import java.util.Map;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

@Component
public class SWDSBOConverter {
Expand All @@ -33,7 +35,13 @@ public SWDSBOConverter(StorageProperties storageProperties) {
}

public SWDataSet fromEntity(SWDatasetVersionEntity swDatasetVersionEntity){
Map<String, FileStorageEnv> fileStorageEnvs = storageProperties.toFileStorageEnvs();
Map<String, FileStorageEnv> fileStorageEnvs;
if(StringUtils.hasText(swDatasetVersionEntity.getStorageAuths())){
StorageAuths storageAuths = new StorageAuths(swDatasetVersionEntity.getStorageAuths());
fileStorageEnvs = storageAuths.allEnvs();
}else {
fileStorageEnvs = storageProperties.toFileStorageEnvs();
}
fileStorageEnvs.values().forEach(fileStorageEnv -> fileStorageEnv.add(FileStorageEnv.ENV_KEY_PREFIX,swDatasetVersionEntity.getStoragePath()));
return SWDataSet.builder()
.id(swDatasetVersionEntity.getId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@
* data store helper for data set
*/
@Component
public class DSRHelper {
public class DataStoreTableNameHelper {

public static final String FORMATTER_TABLE_NAME="SW_TABLE_DST_%s_%s";
final int VERSION_PREFIX_CNT = 2;

public String tableNameOf(String name,String version){
return String.format(FORMATTER_TABLE_NAME,name,version);
public static final String FORMATTER_TABLE_NAME_DATASET ="project/%s/dataset/%s/%s/%s/meta";

public static final String FORMATTER_TABLE_NAME_EVAL_RESULTS ="project/%s/eval/%s/results";

public static final String FORMATTER_TABLE_NAME_EVAL_SUMMARY ="project/%s/eval/summary";

public String tableNameOfDataset(String project,String name,String version){
return String.format(FORMATTER_TABLE_NAME_DATASET,project,name,version.substring(0,VERSION_PREFIX_CNT),version);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ public class IndexItem {
@JsonProperty("data_origin")
String dataOrigin;

@JsonProperty("object_store_type")
String object_store_type;

@JsonProperty("auth_name")
String auth_name;

@JsonProperty("label")
String label;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ List<SWDatasetVersionEntity> listVersions(@Param("datasetId")Long datasetId,

int updateStatus(@Param("id")Long id ,@Param("status")Integer status);

int updateStorageAuths(@Param("id")Long id ,@Param("storageAuths")String storageAuths);

int deleteById(@Param("id")Long id);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.domain.swds.objectstore;

import ai.starwhale.mlops.exception.SWProcessException;
import ai.starwhale.mlops.exception.SWProcessException.ErrorType;
import ai.starwhale.mlops.storage.StorageAccessService;
import java.io.IOException;
import java.io.InputStream;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

@Slf4j
@Service
public class DSFileGetter {

final StorageAccessParser storageAccessParser;

public DSFileGetter(StorageAccessParser storageAccessParser) {
this.storageAccessParser = storageAccessParser;
}

public byte[] dataOf(Long datasetId, String uri, String authName) {
StorageAccessService storageAccessService = storageAccessParser.getStorageAccessServiceFromAuth(
datasetId, uri, authName);
try (InputStream inputStream = storageAccessService.get(new StorageUri(uri).getPath())) {
return inputStream.readAllBytes();
} catch (IOException ioException) {
log.error("error while accessing storage ", ioException);
throw new SWProcessException(ErrorType.STORAGE).tip(
String.format("error while accessing storage : %s", ioException.getMessage()));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.domain.swds.objectstore;

import ai.starwhale.mlops.domain.swds.mapper.SWDatasetVersionMapper;
import ai.starwhale.mlops.domain.swds.po.SWDatasetVersionEntity;
import ai.starwhale.mlops.exception.SWValidationException;
import ai.starwhale.mlops.exception.SWValidationException.ValidSubject;
import ai.starwhale.mlops.storage.StorageAccessService;
import ai.starwhale.mlops.storage.fs.FileStorageEnv;
import ai.starwhale.mlops.storage.fs.FileStorageEnv.FileSystemEnvType;
import ai.starwhale.mlops.storage.s3.S3Config;
import ai.starwhale.mlops.storage.s3.StorageAccessServiceS3;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

@Component
public class StorageAccessParser {

final StorageAccessService defaultStorageAccessService;

final SWDatasetVersionMapper swDatasetVersionMapper;

ConcurrentHashMap<String, StorageAccessService> storageAccessServicePool = new ConcurrentHashMap<>();

public StorageAccessParser(StorageAccessService defaultStorageAccessService,
SWDatasetVersionMapper swDatasetVersionMapper) {
this.defaultStorageAccessService = defaultStorageAccessService;
this.swDatasetVersionMapper = swDatasetVersionMapper;
}

/**
* @param datasetId
* @param uri
* @param authName allow empty
* @return
*/
public StorageAccessService getStorageAccessServiceFromAuth(Long datasetId, String uri,
String authName) {

StorageAccessService cachedStorageAccessService = storageAccessServicePool.get(
formatKey(datasetId, authName));
if (null != cachedStorageAccessService) {
return cachedStorageAccessService;
}
SWDatasetVersionEntity swDatasetVersionEntity = swDatasetVersionMapper.getVersionById(
datasetId);
String storageAuthsText = swDatasetVersionEntity.getStorageAuths();
if (!StringUtils.hasText(storageAuthsText)) {
return defaultStorageAccessService;
}

StorageAuths storageAuths = new StorageAuths(storageAuthsText);
FileStorageEnv env = storageAuths.getEnv(authName);
if (null == env) {
return defaultStorageAccessService;
}
if (env.getEnvType() != FileSystemEnvType.S3) {
throw new SWValidationException(ValidSubject.SWDS).tip(
"file system not supported yet: " + env.getEnvType());
}
StorageAccessServiceS3 storageAccessServiceS3 = new StorageAccessServiceS3(
env2S3Config(new StorageUri(uri), env, authName));
storageAccessServicePool.putIfAbsent(formatKey(datasetId, authName),
storageAccessServiceS3);
return storageAccessServiceS3;
}

String formatKey(Long datasetId, String authName) {
return datasetId.toString() + authName;
}


static final String KEY_BUCKET = "USER.S3.%sBUCKET";
static final String KEY_REGION = "USER.S3.%sREGION";
static final String KEY_ENDPOINT = "USER.S3.%sENDPOINT";
static final String KEY_SECRET = "USER.S3.%sSECRET";
static final String KEY_ACCESS_KEY = "USER.S3.%sACCESS_KEY";

S3Config env2S3Config(StorageUri storageUri, FileStorageEnv env, String authName) {
if (StringUtils.hasText(authName)) {
authName = authName + ".";
} else {
authName = "";
}
Map<String, String> envs = env.getEnvs();
String bucket = StringUtils.hasText(storageUri.getBucket()) ? storageUri.getBucket()
: envs.get(String.format(KEY_BUCKET, authName));
String accessKey = StringUtils.hasText(storageUri.getUsername()) ? storageUri.getUsername()
: envs.get(String.format(KEY_ACCESS_KEY, authName));
String accessSecret =
StringUtils.hasText(storageUri.getPassword()) ? storageUri.getPassword()
: envs.get(String.format(KEY_SECRET, authName));
String endpoint = StringUtils.hasText(storageUri.getHost()) ? buildEndPoint(storageUri)
: envs.get(String.format(KEY_ENDPOINT, authName));
return new S3Config(bucket
, envs.get(accessKey)
, envs.get(accessSecret)
, envs.get(String.format(KEY_REGION, authName))
, envs.get(endpoint));
}

private String buildEndPoint(StorageUri storageUri) {
if (null == storageUri.getPort() || 80 == storageUri.getPort()) {
return "http://" + storageUri.getHost();
} else if (443 == storageUri.getPort()) {
return "https://" + storageUri.getHost();
} else {
return "http://" + storageUri.getHost() + ":" + storageUri.getPort();
}

}

}
Loading