Skip to content

Commit

Permalink
Merge pull request #208 from transformerlab/sanjay-fix-recipes-that-n…
Browse files Browse the repository at this point in the history
…eed-downloading

Update the error message and automatically download a dataset when a part of a recipe is not found
  • Loading branch information
dadmobile authored Jan 16, 2025
2 parents 331b915 + ee0e255 commit e51b503
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 8 deletions.
51 changes: 43 additions & 8 deletions src/renderer/components/Experiment/Train/ImportRecipeModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,51 @@ export default function ImportRecipeModal({ open, setOpen, mutate }) {
} else if (!response.dataset || ! response.dataset.path) {
alert("Warning: This recipe does not have an associated dataset")
} else {
let msg = "";
if (!response.model.downloaded) {
msg += "Download model " + response.model.path
}
let msg = "Warning: To use this recipe you will need to download the following:";
let shouldDownload = false;

if (!response.dataset.downloaded) {
msg += "Download dataset " + response.dataset.path
msg += "\n- Dataset: " + response.dataset.path;
shouldDownload = true;
}
if (!response.model.downloaded) {
msg += "\n- Model: " + response.model.path;
shouldDownload = true;
}
if (msg) {
const alert_msg = "Warning: To use this recipe you will need to: " + msg
alert(alert_msg);

if (shouldDownload) {
msg += "\n\nDo you want to download these now?";
if (confirm(msg)) { // Use confirm() to get Accept/Cancel
if (!response.dataset.downloaded) {
fetch(chatAPI.Endpoints.Dataset.Download(response.dataset.path))
.then((response) => {
if (!response.ok) {
console.log(response);
throw new Error(`HTTP Status: ${response.status}`);
}
return response.json();
})
.catch((error) => {
alert('Dataset download failed:\n' + error);
});
}
if (!response.model.downloaded) {
chatAPI.downloadModelFromHuggingFace(response.model.path)
.then((response) => {
if (response.status == "error") {
console.log(response);
throw new Error(`${response.message}`);
}
return response;
})
.catch((error) => {
alert('Model download failed:\n' + error);
});
}
} else {
// User pressed Cancel
alert("Downloads cancelled. This recipe might not work correctly.");
}
}
}
}
Expand Down
14 changes: 14 additions & 0 deletions src/renderer/components/Experiment/Train/TrainLoRA.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ import ImportRecipeModal from './ImportRecipeModal';
import dayjs from 'dayjs';
import relativeTime from 'dayjs/plugin/relativeTime';
import ViewOutputModalStreaming from './ViewOutputModalStreaming';
import CurrentDownloadBox from 'renderer/components/currentDownloadBox';
import DownloadProgressBox from 'renderer/components/Shared/DownloadProgressBox';
dayjs.extend(relativeTime);
var duration = require('dayjs/plugin/duration');
dayjs.extend(duration);
Expand Down Expand Up @@ -119,6 +121,15 @@ export default function TrainLoRA({ experimentInfo }) {
refreshInterval: 2000,
});

const {
data: downloadJobs,
error: downloadJobsError,
isLoading: downloadJobsIsLoading,
mutate: downloadJobsMutate,
} = useSWR(chatAPI.Endpoints.Jobs.GetJobsOfType('DOWNLOAD_MODEL', 'RUNNING'), fetcher, {
refreshInterval: 2000,
});

//Fetch available training plugins
const {
data: pluginsData,
Expand Down Expand Up @@ -174,6 +185,9 @@ export default function TrainLoRA({ experimentInfo }) {
overflow: 'hidden',
}}
>
{ !downloadJobsIsLoading &&
<DownloadProgressBox jobId={downloadJobs[0]?.id} assetName={downloadJobs[0]?.job_data.model}/>
}
{/* <Typography level="h1">Train</Typography> */}
<Stack direction="row" justifyContent="space-between" gap={2}>
<Typography level="title-md" startDecorator={<GraduationCapIcon />}>
Expand Down
1 change: 1 addition & 0 deletions src/renderer/components/Shared/DownloadProgressBox.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
} from '@mui/joy';

import {
clamp,
formatBytes,
} from '../../lib/utils';
import * as chatAPI from '../../lib/transformerlab-api-sdk';
Expand Down

0 comments on commit e51b503

Please sign in to comment.