diff --git a/gui/app/components/common.py b/gui/app/components.py similarity index 100% rename from gui/app/components/common.py rename to gui/app/components.py diff --git a/gui/app/components/__init__.py b/gui/app/components/__init__.py deleted file mode 100644 index abe4dd3..0000000 --- a/gui/app/components/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""It contains the implementation various components.""" diff --git a/gui/app/components/dataloader_creation.py b/gui/app/components/dataloader_creation.py deleted file mode 100644 index 485c948..0000000 --- a/gui/app/components/dataloader_creation.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Create dataloader components.""" - -from collections.abc import Callable - -import reflex as rx -from reflex_ag_grid import ag_grid - -from .common import SIDEBAR_OPTIONS, control_buttons, home, select_mode, title - - -def checkboxes(row: list[str], state: rx.State) -> rx.Component: - """Checkbox of parameter value.""" - - def _in_leagues(name: str) -> rx.Var: - return state.default_param_checked['leagues'].contains(name.to_string()) - - def _in_years(name: str) -> rx.Var: - return state.default_param_checked['years'].contains(name.to_string()) - - def _in_divisions(name: str) -> rx.Var: - return state.default_param_checked['divisions'].contains(name.to_string()) - - return rx.vstack( - rx.foreach( - row, - lambda name: rx.checkbox( - name, - default_checked=rx.cond( - _in_leagues(name), True, rx.cond(_in_years(name), True, rx.cond(_in_divisions(name), True, False)) - ), - checked=state.param_checked[name.to_string()], - name=name.to_string(), - on_change=lambda checked: state.update_param_checked(name, checked), - ), - ), - ) - - -def dialog(name: str, icon_name: str, state: rx.State) -> Callable: - """Dialog component.""" - - def _dialog(rows: list[list[str]], on_submit: Callable) -> rx.Component: - """The dialog component.""" - return rx.dialog.root( - rx.dialog.trigger( - rx.button( - rx.tooltip(rx.icon(icon_name), content=name), - size='4', - variant='outline', - disabled=state.visibility_level > 3, - ), - ), - rx.dialog.content( - rx.form.root( - rx.dialog.title(name), - rx.dialog.description( - f'Select the {name.lower()} to include in the training data.', - size="2", - margin_bottom="16px", - ), - rx.hstack(rx.foreach(rows, lambda row: checkboxes(row, state))), - rx.flex( - rx.dialog.close(rx.button('Submit', type='submit')), - justify='end', - spacing="3", - margin_top="50px", - ), - on_submit=on_submit, - reset_on_submit=False, - width="100%", - ), - ), - ) - - return _dialog - - -def training_parameters_selection(state: rx.State) -> rx.Component: - """The training parameters selection component.""" - return rx.vstack( - rx.vstack( - rx.text('Odds type', size='1'), - rx.select( - state.odds_types, - default_value=state.odds_types[0], - on_change=state.handle_odds_type, - disabled=state.visibility_level > 4, - width='100%', - ), - ), - rx.vstack( - rx.text('Drop NA threshold of columns', size='1'), - rx.slider( - min=0.0, - max=1.0, - step=0.01, - default_value=0.0, - on_change=state.handle_drop_na_thres, - disabled=state.visibility_level > 4, - ), - style={ - 'margin-top': '15px', - 'width': '100%', - }, - ), - ) - - -def parameters_selection(state: rx.State) -> rx.Component: - """The parameters title.""" - return rx.hstack( - dialog('Leagues', 'earth', state)(state.all_leagues, state.handle_submit_leagues), - dialog('Years', 'calendar', state)(state.all_years, state.handle_submit_years), - dialog('Divisions', 'gauge', state)(state.all_divisions, state.handle_submit_divisions), - ) - - -def main(state: rx.State) -> rx.Component: - """Main container of UI.""" - return rx.container( - rx.vstack( - home(), - rx.divider(), - # Mode selection - title('Mode', 'blend'), - select_mode(state, 'Create a dataloader'), - # Sport selection - rx.cond( - state.visibility_level > 1, - title('Sport', 'medal'), - ), - rx.cond( - state.visibility_level > 1, - rx.text('Select a sport', size='1'), - ), - rx.cond( - state.visibility_level > 1, - rx.select( - items=['Soccer'], - value='Soccer', - disabled=state.visibility_level > 2, - on_change=state.set_sport_selection, - width='120px', - ), - ), - # Parameters selection - rx.cond( - state.visibility_level > 2, - title('Parameters', 'proportions'), - ), - rx.cond( - state.visibility_level > 2, - rx.text('Select parameters', size='1'), - ), - rx.cond( - state.visibility_level > 2, - parameters_selection(state), - ), - # Training parameters selection - rx.cond( - state.visibility_level > 3, - training_parameters_selection(state), - ), - rx.cond( - state.visibility_level > 4, - rx.button( - 'Save', - position='fixed', - top='620px', - left='275px', - width='70px', - on_click=state.download_dataloader, - ), - ), - # Control - control_buttons(state, state.visibility_level == 5), - **SIDEBAR_OPTIONS, - ), - rx.vstack( - rx.cond( - state.visibility_level == 5, - rx.hstack( - rx.heading( - 'Training data', size='7', position='fixed', left='450px', top='50px', color_scheme='blue' - ) - ), - ), - rx.hstack( - rx.vstack( - rx.cond(state.visibility_level == 5, rx.heading('Input')), - rx.cond( - state.visibility_level == 5, - ag_grid( - id='X_train', - row_data=state.X_train, - column_defs=state.X_train_cols, - height='200px', - width='250px', - theme='balham', - ), - ), - ), - rx.vstack( - rx.cond(state.visibility_level == 5, rx.heading('Output')), - rx.cond( - state.visibility_level == 5, - ag_grid( - id='Y_train', - row_data=state.Y_train, - column_defs=state.Y_train_cols, - height='200px', - width='250px', - theme='balham', - ), - ), - ), - rx.vstack( - rx.cond(state.visibility_level == 5, rx.heading('Odds')), - rx.cond( - state.visibility_level == 5, - ag_grid( - id='O_train', - row_data=state.O_train, - column_defs=state.O_train_cols, - height='200px', - width='250px', - theme='balham', - ), - ), - ), - position='fixed', - left='450px', - top='100px', - ), - ), - rx.vstack( - rx.cond( - state.visibility_level == 5, - rx.hstack( - rx.heading( - 'Fixtures data', size='7', position='fixed', left='450px', top='370px', color_scheme='blue' - ) - ), - ), - rx.cond( - state.visibility_level == 5, - rx.cond( - state.X_fix, - rx.hstack( - rx.vstack( - rx.cond(state.visibility_level == 5, rx.heading('Input')), - rx.cond( - state.visibility_level == 5, - ag_grid( - id='X_fix', - row_data=state.X_fix, - column_defs=state.X_fix_cols, - height='200px', - width='250px', - theme='balham', - ), - ), - ), - rx.vstack( - rx.cond(state.visibility_level == 5, rx.heading('Output')), - rx.cond( - state.visibility_level == 5, - ag_grid( - id='Y_fix', - row_data=[], - column_defs=[], - height='200px', - width='250px', - theme='balham', - ), - ), - ), - rx.vstack( - rx.cond(state.visibility_level == 5, rx.heading('Odds')), - rx.cond( - state.visibility_level == 5, - ag_grid( - id='O_fix', - row_data=state.O_fix, - column_defs=state.O_fix_cols, - height='200px', - width='250px', - theme='balham', - ), - ), - ), - position='fixed', - left='450px', - top='420px', - ), - rx.tooltip( - rx.icon('ban', position='fixed', left='450px', top='420px', size=60), - content='No fixtures were found. Try again later.', - ), - ), - ), - ), - ) diff --git a/gui/app/components/dataloader_loading.py b/gui/app/components/dataloader_loading.py deleted file mode 100644 index c204056..0000000 --- a/gui/app/components/dataloader_loading.py +++ /dev/null @@ -1,122 +0,0 @@ -"""Load dataloader components.""" - -from collections.abc import Callable - -import reflex as rx - -from .common import SIDEBAR_OPTIONS, control_buttons, home, select_mode, title - - -def checkboxes(row: list[str], state: rx.State) -> rx.Component: - """Checkbox of parameter value.""" - - return rx.vstack( - rx.foreach( - row, - lambda name: rx.checkbox( - name, - disabled=True, - default_checked=state.param_checked[name.to_string()], - name=name.to_string(), - ), - ), - ) - - -def dialog(name: str, icon_name: str, state: rx.State) -> Callable: - """Dialog component.""" - - def _dialog(rows: list[list[str]]) -> rx.Component: - """The dialog component.""" - return rx.dialog.root( - rx.dialog.trigger( - rx.button( - rx.tooltip(rx.icon(icon_name), content=name), - size='4', - variant='outline', - disabled=state.visibility_level > 3, - ), - ), - rx.dialog.content( - rx.form.root( - rx.dialog.title(name), - rx.dialog.description( - f'{name} included in the training data.', - size="2", - margin_bottom="16px", - ), - rx.hstack(rx.foreach(rows, lambda row: checkboxes(row, state))), - width="100%", - ), - ), - ) - - return _dialog - - -def main(state: rx.State) -> rx.Component: - """Main container of UI.""" - return rx.container( - rx.vstack( - home(), - rx.divider(), - # Mode selection - title('Mode', 'blend'), - select_mode(state, 'Load a dataloader'), - # Dataloader selection - rx.cond( - state.visibility_level > 1, - title('Dataloader', 'database'), - ), - rx.cond( - state.visibility_level > 1, - rx.upload( - rx.vstack( - rx.button( - 'Select File', - bg='white', - color='rgb(107,99,246)', - border=f'1px solid rgb(107,99,246)', - disabled=state.dataloader_serialized.bool(), - ), - rx.text('Drag and drop', size='2'), - ), - id='dataloader', - multiple=False, - no_keyboard=True, - no_drag=state.dataloader_serialized.bool(), - on_drop=state.handle_upload(rx.upload_files(upload_id='dataloader')), - border='1px dotted blue', - padding='35px', - ), - ), - rx.cond( - state.dataloader_serialized, - rx.text(f'Dataloader: {state.dataloader_filename}', size='1'), - ), - # Parameters presentation - rx.cond( - state.visibility_level > 2, - title('Parameters', 'proportions'), - ), - rx.cond( - state.visibility_level > 2, - rx.hstack( - dialog('Leagues', 'earth', state)(state.all_leagues), - dialog('Years', 'calendar', state)(state.all_years), - dialog('Divisions', 'gauge', state)(state.all_divisions), - ), - ), - rx.cond( - state.visibility_level > 2, - rx.text(f'Odds type: {state.odds_type}', size='1'), - ), - rx.cond( - state.visibility_level > 2, - rx.text(f'Drop NA threshold of columns: {state.drop_na_thres}', size='1'), - ), - # Control - control_buttons(state, (~state.dataloader_serialized.bool()) | (state.visibility_level > 2)), - **SIDEBAR_OPTIONS, - ), - ) diff --git a/gui/app/dataloader_creation.py b/gui/app/dataloader_creation.py index ed5b189..33a45a6 100644 --- a/gui/app/dataloader_creation.py +++ b/gui/app/dataloader_creation.py @@ -1,5 +1,6 @@ """Index page.""" +from collections.abc import Callable from itertools import batched from typing import Any, Self @@ -11,7 +12,7 @@ from sportsbet.datasets import SoccerDataLoader -from .components.dataloader_creation import main +from .components import SIDEBAR_OPTIONS, control_buttons, home, select_mode, title from .index import State DATALOADERS = { @@ -296,7 +297,300 @@ def reset_state(self: Self) -> None: self.O_fix_cols = DEFAULT_STATE_VALS['data']['O_fix_cols'] +def checkboxes(row: list[str], state: rx.State) -> rx.Component: + """Checkbox of parameter value.""" + + def _in_leagues(name: str) -> rx.Var: + return state.default_param_checked['leagues'].contains(name.to_string()) + + def _in_years(name: str) -> rx.Var: + return state.default_param_checked['years'].contains(name.to_string()) + + def _in_divisions(name: str) -> rx.Var: + return state.default_param_checked['divisions'].contains(name.to_string()) + + return rx.vstack( + rx.foreach( + row, + lambda name: rx.checkbox( + name, + default_checked=rx.cond( + _in_leagues(name), True, rx.cond(_in_years(name), True, rx.cond(_in_divisions(name), True, False)) + ), + checked=state.param_checked[name.to_string()], + name=name.to_string(), + on_change=lambda checked: state.update_param_checked(name, checked), + ), + ), + ) + + +def dialog(name: str, icon_name: str, state: rx.State) -> Callable: + """Dialog component.""" + + def _dialog(rows: list[list[str]], on_submit: Callable) -> rx.Component: + """The dialog component.""" + return rx.dialog.root( + rx.dialog.trigger( + rx.button( + rx.tooltip(rx.icon(icon_name), content=name), + size='4', + variant='outline', + disabled=state.visibility_level > 3, + ), + ), + rx.dialog.content( + rx.form.root( + rx.dialog.title(name), + rx.dialog.description( + f'Select the {name.lower()} to include in the training data.', + size="2", + margin_bottom="16px", + ), + rx.hstack(rx.foreach(rows, lambda row: checkboxes(row, state))), + rx.flex( + rx.dialog.close(rx.button('Submit', type='submit')), + justify='end', + spacing="3", + margin_top="50px", + ), + on_submit=on_submit, + reset_on_submit=False, + width="100%", + ), + ), + ) + + return _dialog + + +def training_parameters_selection(state: rx.State) -> rx.Component: + """The training parameters selection component.""" + return rx.vstack( + rx.vstack( + rx.text('Odds type', size='1'), + rx.select( + state.odds_types, + default_value=state.odds_types[0], + on_change=state.handle_odds_type, + disabled=state.visibility_level > 4, + width='100%', + ), + style={ + 'margin-top': '5px', + }, + ), + rx.vstack( + rx.text(f'Drop NA threshold of columns: {DataloaderCreationState.drop_na_thres}', size='1'), + rx.slider( + min=0.0, + max=1.0, + step=0.01, + default_value=0.0, + on_change=state.handle_drop_na_thres, + disabled=state.visibility_level > 4, + width='200px', + ), + style={ + 'margin-top': '5px', + }, + ), + ) + + +def parameters_selection(state: rx.State) -> rx.Component: + """The parameters title.""" + return rx.hstack( + dialog('Leagues', 'earth', state)(state.all_leagues, state.handle_submit_leagues), + dialog('Years', 'calendar', state)(state.all_years, state.handle_submit_years), + dialog('Divisions', 'gauge', state)(state.all_divisions, state.handle_submit_divisions), + ) + + @rx.page(route="/dataloader/creation") def dataloader_creation_page() -> rx.Component: """Main page.""" - return main(DataloaderCreationState) + return rx.container( + rx.vstack( + home(), + rx.divider(), + # Mode selection + title('Mode', 'blend'), + select_mode(DataloaderCreationState, 'Create a dataloader'), + # Sport selection + rx.cond( + DataloaderCreationState.visibility_level > 1, + title('Sport', 'medal'), + ), + rx.cond( + DataloaderCreationState.visibility_level > 1, + rx.text('Select a sport', size='1'), + ), + rx.cond( + DataloaderCreationState.visibility_level > 1, + rx.select( + items=['Soccer'], + value='Soccer', + disabled=DataloaderCreationState.visibility_level > 2, + on_change=DataloaderCreationState.set_sport_selection, + width='120px', + ), + ), + # Parameters selection + rx.cond( + DataloaderCreationState.visibility_level > 2, + title('Parameters', 'proportions'), + ), + rx.cond( + DataloaderCreationState.visibility_level > 2, + rx.text('Select parameters', size='1'), + ), + rx.cond( + DataloaderCreationState.visibility_level > 2, + parameters_selection(DataloaderCreationState), + ), + # Training parameters selection + rx.cond( + DataloaderCreationState.visibility_level > 3, + training_parameters_selection(DataloaderCreationState), + ), + rx.cond( + DataloaderCreationState.visibility_level > 4, + rx.button( + 'Save', + position='fixed', + top='620px', + left='275px', + width='70px', + on_click=DataloaderCreationState.download_dataloader, + ), + ), + # Control + control_buttons(DataloaderCreationState, DataloaderCreationState.visibility_level == 5), + **SIDEBAR_OPTIONS, + ), + rx.vstack( + rx.cond( + DataloaderCreationState.visibility_level == 5, + rx.hstack( + rx.heading( + 'Training data', size='7', position='fixed', left='450px', top='50px', color_scheme='blue' + ) + ), + ), + rx.hstack( + rx.vstack( + rx.cond(DataloaderCreationState.visibility_level == 5, rx.heading('Input')), + rx.cond( + DataloaderCreationState.visibility_level == 5, + ag_grid( + id='X_train', + row_data=DataloaderCreationState.X_train, + column_defs=DataloaderCreationState.X_train_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(DataloaderCreationState.visibility_level == 5, rx.heading('Output')), + rx.cond( + DataloaderCreationState.visibility_level == 5, + ag_grid( + id='Y_train', + row_data=DataloaderCreationState.Y_train, + column_defs=DataloaderCreationState.Y_train_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(DataloaderCreationState.visibility_level == 5, rx.heading('Odds')), + rx.cond( + DataloaderCreationState.visibility_level == 5, + ag_grid( + id='O_train', + row_data=DataloaderCreationState.O_train, + column_defs=DataloaderCreationState.O_train_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + position='fixed', + left='450px', + top='100px', + ), + ), + rx.vstack( + rx.cond( + DataloaderCreationState.visibility_level == 5, + rx.hstack( + rx.heading( + 'Fixtures data', size='7', position='fixed', left='450px', top='370px', color_scheme='blue' + ) + ), + ), + rx.cond( + DataloaderCreationState.visibility_level == 5, + rx.cond( + DataloaderCreationState.X_fix, + rx.hstack( + rx.vstack( + rx.cond(DataloaderCreationState.visibility_level == 5, rx.heading('Input')), + rx.cond( + DataloaderCreationState.visibility_level == 5, + ag_grid( + id='X_fix', + row_data=DataloaderCreationState.X_fix, + column_defs=DataloaderCreationState.X_fix_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(DataloaderCreationState.visibility_level == 5, rx.heading('Output')), + rx.cond( + DataloaderCreationState.visibility_level == 5, + ag_grid( + id='Y_fix', + row_data=[], + column_defs=[], + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(DataloaderCreationState.visibility_level == 5, rx.heading('Odds')), + rx.cond( + DataloaderCreationState.visibility_level == 5, + ag_grid( + id='O_fix', + row_data=DataloaderCreationState.O_fix, + column_defs=DataloaderCreationState.O_fix_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + position='fixed', + left='450px', + top='420px', + ), + rx.tooltip( + rx.icon('ban', position='fixed', left='450px', top='420px', size=60), + content='No fixtures were found. Try again later.', + ), + ), + ), + ), + ) diff --git a/gui/app/dataloader_loading.py b/gui/app/dataloader_loading.py index f937fc4..242fa1b 100644 --- a/gui/app/dataloader_loading.py +++ b/gui/app/dataloader_loading.py @@ -1,5 +1,6 @@ """Index page.""" +from collections.abc import Callable from itertools import batched from pathlib import Path from typing import Self @@ -7,21 +8,17 @@ import cloudpickle import nest_asyncio import reflex as rx +from reflex.event import EventSpec +from reflex_ag_grid import ag_grid from sportsbet.datasets import SoccerDataLoader -from .components.dataloader_loading import main +from .components import SIDEBAR_OPTIONS, control_buttons, home, select_mode, title from .index import State DATALOADERS = { 'Soccer': SoccerDataLoader, } -DEFAULT_STATE_VALS = { - 'mode': { - 'category': 'Data', - 'type': 'Create', - }, -} nest_asyncio.apply() @@ -38,6 +35,16 @@ class DataloaderLoadingState(State): param_checked: dict[str, bool] = {} odds_type: str | None = None drop_na_thres: float | None = None + X_train: list | None = None + Y_train: list | None = None + O_train: list | None = None + X_train_cols: list | None = None + Y_train_cols: list | None = None + O_train_cols: list | None = None + X_fix: list | None = None + O_fix: list | None = None + X_fix_cols: list | None = None + O_fix_cols: list | None = None @rx.event async def handle_upload(self: Self, files: list[rx.UploadFile]) -> None: @@ -51,6 +58,17 @@ async def handle_upload(self: Self, files: list[rx.UploadFile]) -> None: self.loading = False yield + @rx.event + def download_dataloader(self: Self) -> EventSpec: + """Download the dataloader.""" + dataloader = bytes(self.dataloader_serialized, 'iso8859_16') + return rx.download(data=dataloader, filename='dataloader.pkl') + + @staticmethod + def process_cols(col: str) -> str: + """Proces a column.""" + return " ".join([" ".join(token.split('_')).title() for token in col.split('__')]) + def submit_state(self: Self) -> None: """Submit handler.""" self.loading = True @@ -60,11 +78,39 @@ def submit_state(self: Self) -> None: yield elif self.visibility_level == 2: dataloader = cloudpickle.loads(bytes(self.dataloader_serialized, 'iso8859_16')) + X_train, Y_train, O_train = dataloader.extract_train_data( + odds_type=dataloader.odds_type_, + drop_na_thres=dataloader.drop_na_thres_, + ) + X_fix, _, O_fix = dataloader.extract_fixtures_data() + self.X_train = X_train.reset_index().to_dict('records') + self.X_train_cols = [ag_grid.column_def(field='date', header_name='Date')] + [ + ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in X_train.columns + ] + self.Y_train = Y_train.to_dict('records') + self.Y_train_cols = [ + ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in Y_train.columns + ] + self.O_train = O_train.to_dict('records') if O_train is not None else None + self.O_train_cols = ( + [ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in O_train.columns] + if O_train is not None + else None + ) + self.X_fix = X_fix.reset_index().to_dict('records') + self.X_fix_cols = [ag_grid.column_def(field='date', header_name='Date')] + [ + ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in X_fix.columns + ] + self.O_fix = O_fix.to_dict('records') if O_fix is not None else None + self.O_fix_cols = ( + [ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in O_fix.columns] + if O_fix is not None + else None + ) all_params = dataloader.get_all_params() self.all_leagues = list(batched(sorted({params['league'] for params in all_params}), 6)) self.all_years = list(batched(sorted({params['year'] for params in all_params}), 5)) self.all_divisions = list(batched(sorted({params['division'] for params in all_params}), 1)) - self.loading = False self.param_checked = { **{f'"{key}"': True for key in {params['league'] for params in dataloader.param_grid_}}, **{key: True for key in {params['year'] for params in dataloader.param_grid_}}, @@ -72,6 +118,7 @@ def submit_state(self: Self) -> None: } self.odds_type = dataloader.odds_type_ self.drop_na_thres = dataloader.drop_na_thres_ + self.loading = False yield self.visibility_level += 1 @@ -95,9 +142,267 @@ def reset_state(self: Self) -> None: self.param_checked = {} self.odds_type = None self.drop_na_thres = None + self.X_train = None + self.Y_train = None + self.O_train = None + self.X_train_cols = None + self.Y_train_cols = None + self.O_train_cols = None + self.X_fix = None + self.O_fix = None + self.X_fix_cols = None + self.O_fix_cols = None + + +def checkboxes(row: list[str], state: rx.State) -> rx.Component: + """Checkbox of parameter value.""" + + return rx.vstack( + rx.foreach( + row, + lambda name: rx.checkbox( + name, + disabled=True, + default_checked=state.param_checked[name.to_string()], + name=name.to_string(), + ), + ), + ) + + +def dialog(name: str, icon_name: str, state: rx.State) -> Callable: + """Dialog component.""" + + def _dialog(rows: list[list[str]]) -> rx.Component: + """The dialog component.""" + return rx.dialog.root( + rx.dialog.trigger( + rx.button( + rx.tooltip(rx.icon(icon_name), content=name), + size='4', + variant='outline', + disabled=state.visibility_level > 3, + ), + ), + rx.dialog.content( + rx.form.root( + rx.dialog.title(name), + rx.dialog.description( + f'{name} included in the training data.', + size="2", + margin_bottom="16px", + ), + rx.hstack(rx.foreach(rows, lambda row: checkboxes(row, state))), + width="100%", + ), + ), + ) + + return _dialog @rx.page(route="/dataloader/loading") def dataloader_loading_page() -> rx.Component: """Main page.""" - return main(DataloaderLoadingState) + return rx.container( + rx.vstack( + home(), + rx.divider(), + # Mode selection + title('Mode', 'blend'), + select_mode(DataloaderLoadingState, 'Load a dataloader'), + # Dataloader selection + rx.cond( + DataloaderLoadingState.visibility_level > 1, + title('Dataloader', 'database'), + ), + rx.cond( + DataloaderLoadingState.visibility_level > 1, + rx.upload( + rx.vstack( + rx.button( + 'Select File', + bg='white', + color='rgb(107,99,246)', + border=f'1px solid rgb(107,99,246)', + disabled=DataloaderLoadingState.dataloader_serialized.bool(), + ), + rx.text('Drag and drop', size='2'), + ), + id='dataloader', + multiple=False, + no_keyboard=True, + no_drag=DataloaderLoadingState.dataloader_serialized.bool(), + on_drop=DataloaderLoadingState.handle_upload(rx.upload_files(upload_id='dataloader')), + border='1px dotted blue', + padding='35px', + ), + ), + rx.cond( + DataloaderLoadingState.dataloader_serialized, + rx.text(f'Dataloader: {DataloaderLoadingState.dataloader_filename}', size='1'), + ), + # Parameters presentation + rx.cond( + DataloaderLoadingState.visibility_level > 2, + title('Parameters', 'proportions'), + ), + rx.cond( + DataloaderLoadingState.visibility_level > 2, + rx.hstack( + dialog('Leagues', 'earth', DataloaderLoadingState)(DataloaderLoadingState.all_leagues), + dialog('Years', 'calendar', DataloaderLoadingState)(DataloaderLoadingState.all_years), + dialog('Divisions', 'gauge', DataloaderLoadingState)(DataloaderLoadingState.all_divisions), + ), + ), + rx.cond( + DataloaderLoadingState.visibility_level > 2, + rx.text(f'Odds type: {DataloaderLoadingState.odds_type}', size='1'), + ), + rx.cond( + DataloaderLoadingState.visibility_level > 2, + rx.text(f'Drop NA threshold of columns: {DataloaderLoadingState.drop_na_thres}', size='1'), + ), + rx.cond( + DataloaderLoadingState.visibility_level > 2, + rx.button( + 'Save', + position='fixed', + top='620px', + left='275px', + width='70px', + on_click=DataloaderLoadingState.download_dataloader, + ), + ), + # Control + control_buttons( + DataloaderLoadingState, + (~DataloaderLoadingState.dataloader_serialized.bool()) | (DataloaderLoadingState.visibility_level > 2), + ), + **SIDEBAR_OPTIONS, + ), + rx.vstack( + rx.cond( + DataloaderLoadingState.visibility_level == 3, + rx.hstack( + rx.heading( + 'Training data', size='7', position='fixed', left='450px', top='50px', color_scheme='blue' + ) + ), + ), + rx.hstack( + rx.vstack( + rx.cond(DataloaderLoadingState.visibility_level == 3, rx.heading('Input')), + rx.cond( + DataloaderLoadingState.visibility_level == 3, + ag_grid( + id='X_train', + row_data=DataloaderLoadingState.X_train, + column_defs=DataloaderLoadingState.X_train_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(DataloaderLoadingState.visibility_level == 3, rx.heading('Output')), + rx.cond( + DataloaderLoadingState.visibility_level == 3, + ag_grid( + id='Y_train', + row_data=DataloaderLoadingState.Y_train, + column_defs=DataloaderLoadingState.Y_train_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(DataloaderLoadingState.visibility_level == 3, rx.heading('Odds')), + rx.cond( + DataloaderLoadingState.visibility_level == 3, + ag_grid( + id='O_train', + row_data=DataloaderLoadingState.O_train, + column_defs=DataloaderLoadingState.O_train_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + position='fixed', + left='450px', + top='100px', + ), + ), + rx.vstack( + rx.cond( + DataloaderLoadingState.visibility_level == 3, + rx.hstack( + rx.heading( + 'Fixtures data', size='7', position='fixed', left='450px', top='370px', color_scheme='blue' + ) + ), + ), + rx.cond( + DataloaderLoadingState.visibility_level == 3, + rx.cond( + DataloaderLoadingState.X_fix, + rx.hstack( + rx.vstack( + rx.cond(DataloaderLoadingState.visibility_level == 3, rx.heading('Input')), + rx.cond( + DataloaderLoadingState.visibility_level == 3, + ag_grid( + id='X_fix', + row_data=DataloaderLoadingState.X_fix, + column_defs=DataloaderLoadingState.X_fix_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(DataloaderLoadingState.visibility_level == 3, rx.heading('Output')), + rx.cond( + DataloaderLoadingState.visibility_level == 3, + ag_grid( + id='Y_fix', + row_data=[], + column_defs=[], + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(DataloaderLoadingState.visibility_level == 3, rx.heading('Odds')), + rx.cond( + DataloaderLoadingState.visibility_level == 3, + ag_grid( + id='O_fix', + row_data=DataloaderLoadingState.O_fix, + column_defs=DataloaderLoadingState.O_fix_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + position='fixed', + left='450px', + top='420px', + ), + rx.tooltip( + rx.icon('ban', position='fixed', left='450px', top='420px', size=60), + content='No fixtures were found. Try again later.', + ), + ), + ), + ), + ) diff --git a/gui/app/index.py b/gui/app/index.py index caacf00..be6c774 100644 --- a/gui/app/index.py +++ b/gui/app/index.py @@ -4,7 +4,7 @@ import reflex as rx -from .components.common import SIDEBAR_OPTIONS, home, title +from .components import SIDEBAR_OPTIONS, home, title class State(rx.State):