-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor the CLI to be separate commands
This refactor aims to make the commands more organized and easier to work it. The idea behind this is that we can add new commands and separate the logic for all of them in their own individual modules. Currently, this is a rewrite of what we had in __main__.py, so we have a `query`, `history` and `record` command. In the future, we may have more or less.
- Loading branch information
Showing
20 changed files
with
604 additions
and
145 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import logging | ||
from argparse import Namespace | ||
|
||
from command_line_assistant.config import Config | ||
from command_line_assistant.history import handle_history_write | ||
from command_line_assistant.utils.cli import BaseCLICommand, SubParsersAction | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class HistoryCommand(BaseCLICommand): | ||
def __init__(self, clear: bool, config: Config) -> None: | ||
self._clear = clear | ||
self._config = config | ||
super().__init__() | ||
|
||
def run(self) -> None: | ||
if self._clear: | ||
logger.info("Clearing history of conversation") | ||
handle_history_write(self._config, [], "") | ||
|
||
|
||
def register_subcommand(parser: SubParsersAction, config: Config): | ||
""" | ||
Register this command to argparse so it's available for the datasets-cli | ||
Args: | ||
parser: Root parser to register command-specific arguments | ||
""" | ||
history_parser = parser.add_parser( | ||
"history", | ||
help="Manage conversation history", | ||
) | ||
history_parser.add_argument( | ||
"--clear", action="store_true", help="Clear the history." | ||
) | ||
|
||
# TODO(r0x0d): This is temporary as it will get removed | ||
history_parser.set_defaults(func=lambda args: _command_factory(args, config)) | ||
|
||
|
||
def _command_factory(args: Namespace, config: Config) -> HistoryCommand: | ||
return HistoryCommand(args.clear, config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from argparse import Namespace | ||
|
||
from command_line_assistant.config import Config | ||
from command_line_assistant.handlers import handle_query | ||
from command_line_assistant.utils.cli import BaseCLICommand, SubParsersAction | ||
|
||
|
||
class QueryCommand(BaseCLICommand): | ||
def __init__(self, query_string: str, config: Config) -> None: | ||
self._query = query_string | ||
self._config = config | ||
super().__init__() | ||
|
||
def run(self) -> None: | ||
handle_query(self._query, self._config) | ||
|
||
|
||
def register_subcommand(parser: SubParsersAction, config: Config) -> None: | ||
""" | ||
Register this command to argparse so it's available for the datasets-cli | ||
Args: | ||
parser: Root parser to register command-specific arguments | ||
""" | ||
query_parser = parser.add_parser( | ||
"query", | ||
help="", | ||
) | ||
# Positional argument, required only if no optional arguments are provided | ||
query_parser.add_argument( | ||
"query_string", nargs="?", help="Query string to be processed." | ||
) | ||
|
||
# TODO(r0x0d): This is temporary as it will get removed | ||
query_parser.set_defaults(func=lambda args: _command_factory(args, config)) | ||
|
||
|
||
def _command_factory(args: Namespace, config: Config) -> QueryCommand: | ||
return QueryCommand(args.query_string, config) | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import logging | ||
import os | ||
import sys | ||
|
||
from command_line_assistant.config import Config | ||
from command_line_assistant.handlers import handle_script_session | ||
from command_line_assistant.utils.cli import BaseCLICommand, SubParsersAction | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# NOTE: This needs more refinement, script session can't be combined with other arguments | ||
|
||
|
||
class RecordCommand(BaseCLICommand): | ||
def __init__(self, config: Config) -> None: | ||
self._config = config | ||
super().__init__() | ||
|
||
def run(self) -> None: | ||
enforce_script_session = self._config.output.enforce_script | ||
output_file = self._config.output.file | ||
|
||
if enforce_script_session and not os.path.exists(output_file): | ||
logger.error( | ||
"Please call `%s record` first to initialize script session or create the output file.", | ||
sys.argv[0], | ||
) | ||
|
||
handle_script_session(output_file) | ||
|
||
|
||
def register_subcommand(parser: SubParsersAction, config: Config): | ||
""" | ||
Register this command to argparse so it's available for the datasets-cli | ||
Args: | ||
parser: Root parser to register command-specific arguments | ||
""" | ||
record_parser = parser.add_parser( | ||
"record", | ||
help="Start a recording session for script output.", | ||
) | ||
|
||
# TODO(r0x0d): This is temporary as it will get removed | ||
record_parser.set_defaults(func=lambda args: _command_factory(config)) | ||
|
||
|
||
def _command_factory(config: Config) -> RecordCommand: | ||
return RecordCommand(config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
VERSION = "0.1.0" |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import select | ||
import sys | ||
from abc import ABC, abstractmethod | ||
from argparse import SUPPRESS, ArgumentParser, _SubParsersAction | ||
from typing import Optional | ||
|
||
from command_line_assistant.config import CONFIG_DEFAULT_PATH | ||
from command_line_assistant.constants import VERSION | ||
|
||
# Define the type here so pyright is happy with it. | ||
SubParsersAction = _SubParsersAction | ||
|
||
PARENT_ARGS: list[str] = ["--version", "-v", "-h", "--help"] | ||
ARGS_WITH_VALUES: list[str] = ["--clear"] | ||
|
||
|
||
class BaseCLICommand(ABC): | ||
@abstractmethod | ||
def run(self): | ||
raise NotImplementedError("Not implemented in base class.") | ||
|
||
|
||
def add_default_command(argv): | ||
"""Add the default command when none is given""" | ||
args = argv[1:] | ||
|
||
# Early exit if we don't have any argv | ||
if not args: | ||
return args | ||
|
||
subcommand = _subcommand_used(argv) | ||
if subcommand is None: | ||
args.insert(0, "query") | ||
|
||
return args | ||
|
||
|
||
def _subcommand_used(args): | ||
"""Return what subcommand has been used by the user. Return None if no subcommand has been used.""" | ||
for index, argument in enumerate(args): | ||
# If we have a exact match for any of the commands, return directly | ||
if argument in ("query", "history"): | ||
return argument | ||
|
||
# It means that we hit a --version/--help | ||
if argument in PARENT_ARGS: | ||
return argument | ||
|
||
# Otherwise, check if this is the second part of an arg that takes a value. | ||
elif args[index - 1] in ARGS_WITH_VALUES: | ||
continue | ||
|
||
return None | ||
|
||
|
||
def create_argument_parser() -> tuple[ArgumentParser, SubParsersAction]: | ||
"""Create the argument parser for command line assistant.""" | ||
parser = ArgumentParser( | ||
description="A script with multiple optional arguments and a required positional argument if no optional arguments are provided.", | ||
) | ||
parser.add_argument( | ||
"--version", | ||
action="version", | ||
version=VERSION, | ||
default=SUPPRESS, | ||
help="Show command line assistant version and exit.", | ||
) | ||
parser.add_argument( | ||
"--config", | ||
default=CONFIG_DEFAULT_PATH, | ||
help="Path to the config file.", | ||
) | ||
commands_parser = parser.add_subparsers( | ||
dest="command", help="command line assistant helpers" | ||
) | ||
|
||
return parser, commands_parser | ||
|
||
|
||
def read_stdin() -> Optional[str]: | ||
"""Parse the std input when a user give us. | ||
For example, consider the following scenario: | ||
>>> echo "how to run podman?" | c | ||
Or a more complex one | ||
>>> cat error-log | c "How to fix this?" | ||
Returns: | ||
In case we have a stdin, we parse and retrieve it. Otherwise, just | ||
return None. | ||
""" | ||
# Check if there's input available on stdin | ||
if select.select([sys.stdin], [], [], 0.0)[0]: | ||
# If there is input, read it | ||
input_data = sys.stdin.read().strip() | ||
return input_data | ||
# If no input, return None or handle as you prefer | ||
return None |
Oops, something went wrong.