Last refactoring bundle for 0.4
cbalioglu committed Feb 7, 2025
1 parent 7acb38b commit 2564806
Showing 54 changed files with 620 additions and 198 deletions.
6 changes: 3 additions & 3 deletions src/fairseq2/assets/cards/models/jepa.yaml
@@ -7,7 +7,7 @@
name: jepa_vitl16
model_family: jepa
model_arch: large
model_config_override:
model_config:
encoder_config:
_set_:
input_dims: [16, 224, 224]
@@ -20,7 +20,7 @@ checkpoint: "https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar"
name: jepa_vith16
model_family: jepa
model_arch: huge
model_config_override:
model_config:
encoder_config:
_set_:
input_dims: [16, 224, 224]
@@ -33,7 +33,7 @@ checkpoint: "https://dl.fbaipublicfiles.com/jepa/vith16/vith16.pth.tar"
name: jepa_vith16_384
model_family: jepa
model_arch: huge
model_config_override:
model_config:
encoder_config:
_set_:
input_dims: [16, 384, 384]
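
For reference, the three jepa cards above share the same nested layout, which the diff view flattens. Reassembled (indentation inferred), the updated jepa_vitl16 entry reads roughly as follows, with model_config carrying a _set_ override for the encoder input dimensions:

name: jepa_vitl16
model_family: jepa
model_arch: large
model_config:
  encoder_config:
    _set_:
      input_dims: [16, 224, 224]
checkpoint: "https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar"

The llama, s2t_conformer, and s2t_transformer cards below switch from model_config_override to model_config in the same way.
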
2 changes: 1 addition & 1 deletion src/fairseq2/assets/cards/models/llama.yaml
@@ -75,7 +75,7 @@ use_v2_tokenizer: true

name: llama3_instruct
base: llama3
model_config_override:
model_config:
vocab_info:
_set_:
eos_idx: 128009 # end-of-turn
2 changes: 1 addition & 1 deletion src/fairseq2/assets/cards/models/s2t_conformer.yaml
@@ -18,7 +18,7 @@ tokenizer_family: s2t_transformer
name: s2t_conformer_covost_st_en_de_rel_pos
model_family: s2t_transformer
model_arch: conformer_medium
model_config_override:
model_config:
_set_:
use_relative_pos: true
task: translation
6 changes: 3 additions & 3 deletions src/fairseq2/assets/cards/models/s2t_transformer.yaml
@@ -7,7 +7,7 @@
name: s2t_transformer_mustc_asr_de_s
model_family: s2t_transformer
model_arch: small
model_config_override:
model_config:
target_vocab_info:
size: 5000
task: transcription
@@ -21,7 +21,7 @@ tokenizer_family: s2t_transformer
name: s2t_transformer_mustc_asr_es_s
model_family: s2t_transformer
model_arch: small
model_config_override:
model_config:
target_vocab_info:
_set_:
size: 5000
@@ -47,7 +47,7 @@ tokenizer_family: s2t_transformer
name: s2t_transformer_mustc_st_de_s
model_family: s2t_transformer
model_arch: small
model_config_override:
model_config:
target_vocab_info:
_set_:
size: 8000
2 changes: 1 addition & 1 deletion src/fairseq2/checkpoint/_metadata_provider.py
@@ -61,7 +61,7 @@ def save(
metadata: dict[str, object] = {
"name": "checkpoint",
"model_family": model_family,
"model_config_override": {
"model_config": {
"_set_": unstructured_config,
},
}
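
The hunk above applies the same rename to the metadata that checkpoint saving records for a trained model. Assuming this dict is ultimately serialized as a YAML asset card for the checkpoint (an assumption; only the dict construction is visible in the hunk), the saved entry would look roughly like:

name: checkpoint
model_family: <value of the model_family argument>
model_config:
  _set_:
    <the unstructured model config captured at save time>
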
1 change: 1 addition & 0 deletions src/fairseq2/cli/__init__.py
@@ -7,6 +7,7 @@
from __future__ import annotations

from fairseq2.cli._cli import Cli as Cli
from fairseq2.cli._cli import CliArgumentError as CliArgumentError
from fairseq2.cli._cli import CliCommand as CliCommand
from fairseq2.cli._cli import CliCommandHandler as CliCommandHandler
from fairseq2.cli._cli import CliGroup as CliGroup
16 changes: 16 additions & 0 deletions src/fairseq2/cli/_cli.py
@@ -121,6 +121,10 @@ def run(self, context: RuntimeContext) -> int:

try:
return args.command.run(context, args) # type: ignore[no-any-return]
except CliArgumentError as ex:
log.error(str(ex), ex=ex.__cause__)

return 2
except ProgramError:
log.exception("Command failed. See logged stack trace for details.")

@@ -386,3 +390,15 @@ def run(
self, context: RuntimeContext, parser: ArgumentParser, args: Namespace
) -> int:
"""Run the command."""


class CliArgumentError(Exception):
param_name: str | None

def __init__(self, param_name: str | None, message: str) -> None:
if param_name is not None:
message = f"argument: {param_name}: {message}"

super().__init__(message)

self.param_name = param_name
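
Taken together with the new except clause in Cli.run() above, the pattern is: a command handler raises CliArgumentError instead of logging and returning 2 itself, and the CLI front end logs the message (prefixed with "argument: <param_name>:") and exits with status 2. Below is a minimal sketch of a handler using it; the ShowAssetHandler class and the known_assets set are hypothetical, while CliArgumentError is the new class above and CliCommandHandler with this run() signature appears elsewhere in the diff:

from argparse import ArgumentParser, Namespace

from fairseq2.cli import CliArgumentError, CliCommandHandler
from fairseq2.context import RuntimeContext

known_assets = {"jepa_vitl16", "llama3"}  # hypothetical stand-in for the asset store


class ShowAssetHandler(CliCommandHandler):  # hypothetical example handler
    def init_parser(self, parser: ArgumentParser) -> None:
        parser.add_argument("name", help="asset name")

    def run(
        self, context: RuntimeContext, parser: ArgumentParser, args: Namespace
    ) -> int:
        if args.name not in known_assets:
            # Cli.run() catches this, logs "argument: name: '<name>' is not a
            # known asset.", and returns exit code 2.
            raise CliArgumentError("name", f"'{args.name}' is not a known asset.")

        return 0
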
8 changes: 4 additions & 4 deletions src/fairseq2/cli/commands/assets.py
@@ -22,7 +22,7 @@
AssetMetadataLoadError,
AssetStore,
)
from fairseq2.cli import CliCommandHandler
from fairseq2.cli import CliArgumentError, CliCommandHandler
from fairseq2.cli.utils.rich import get_console
from fairseq2.context import RuntimeContext
from fairseq2.error import ProgramError
@@ -187,9 +187,9 @@ def run(
args.name, envs=args.envs, scope=scope
)
except AssetCardNotFoundError:
log.error("argument name: '{}' is not a known asset. Use `fairseq2 assets list` to see the available assets.", args.name) # fmt: skip

return 2
raise CliArgumentError(
"name", f"'{args.name}' is not a known asset. Use `fairseq2 assets list` to see the available assets." # fmt: skip
) from None
except AssetCardError as ex:
raise ProgramError(
f"The '{args.name}' asset card cannot be read. See the nested exception for details."
34 changes: 17 additions & 17 deletions src/fairseq2/cli/commands/chatbot.py
@@ -16,7 +16,7 @@
from typing_extensions import override

from fairseq2.chatbots import Chatbot, ChatbotHandler, ChatMessage, UnknownChatbotError
from fairseq2.cli import CliCommandHandler
from fairseq2.cli import CliArgumentError, CliCommandHandler
from fairseq2.cli.utils.argparse import parse_dtype
from fairseq2.cli.utils.cluster import set_torch_distributed_variables
from fairseq2.cli.utils.rich import get_console
@@ -37,7 +37,7 @@
setup_gangs,
setup_reference_model,
)
from fairseq2.recipes.config import GangSection, TextTokenizerSection
from fairseq2.recipes.config import GangSection, ReferenceModelSection
from fairseq2.typing import CPU
from fairseq2.utils.rng import RngBag

@@ -49,7 +49,7 @@ def init_parser(self, parser: ArgumentParser) -> None:
parser.add_argument(
"-m",
"--model",
dest="model",
dest="model_name",
metavar="MODEL_NAME",
default="llama3_1_8b_instruct",
help="instruct model name (default: %(default)s)",
@@ -110,47 +110,47 @@ def run(
) -> int:
console = get_console()

view = CliChatbotView(args.model, console)
view = CliChatbotView(args.model_name, console)

try:
set_torch_distributed_variables(context, args.cluster)
except UnknownClusterError as ex:
s = ", ".join(ex.supported_clusters)

log.error("argument --cluster: '{}' is not a known cluster. Must be one of: auto, none, {}", ex.cluster, s) # fmt: skip

return 2
raise CliArgumentError(
"cluster", f"'{ex.cluster}' is not a known cluster. Must be one of: auto, none, {s}" # fmt: skip
) from None
except ClusterError as ex:
if ex.cluster == "slurm":
log.exception("'{}' cluster environment cannot be set. See logged stack trace for details. If you are within an allocated Slurm job (i.e. `salloc`), make sure to run with `srun`. If you want to run without Slurm, use `--cluster none`.", ex.cluster) # fmt: skip
message = f"'{ex.cluster}' cluster environment cannot be set. See logged stack trace for details. If you are within an allocated Slurm job (i.e. `salloc`), make sure to run with `srun`. If you want to run without Slurm, use `--cluster none`."
else:
log.exception("'{}' cluster environment cannot be set. See logged stack trace for details.", ex.cluster) # fmt: skip

return 1
message = f"'{ex.cluster}' cluster environment cannot be set. See logged stack trace for details."

torch.set_float32_matmul_precision("high")
raise ProgramError(message) from ex

args.gang = GangSection(
tensor_parallel_size=args.tensor_parallel_size, timeout=999
)

torch.set_float32_matmul_precision("high")

gangs = setup_gangs(context, args)

if gangs.dp.size > 1:
log.warning("Using redundant data parallelism which may reduce token throughput. It is recommended to use one device per model shard (i.e. a single device for a non-sharded model).") # fmt: skip

args.model = ReferenceModelSection(name=args.model_name)

model = setup_reference_model(
DecoderModel,
context,
args.model,
args.model_name,
gangs,
args.dtype,
mp=False,
torch_compile=False,
)

args.text_tokenizer = TextTokenizerSection(name=args.model)

tokenizer = load_text_tokenizer(context, args)

sampler = TopPSampler(p=args.top_p)
@@ -159,7 +159,7 @@ def run(
model, sampler, temperature=args.temperature, max_gen_len=args.max_gen_len
)

card = context.asset_store.retrieve_card(args.model)
card = context.asset_store.retrieve_card(args.model_name)

family = card.field("model_family").as_(str)

@@ -168,7 +168,7 @@ def run(
try:
chatbot_handler = chatbot_handlers.get(family)
except LookupError:
raise UnknownChatbotError(args.model) from None
raise UnknownChatbotError(args.model_name) from None

chatbot = chatbot_handler.create(generator, tokenizer)

30 changes: 14 additions & 16 deletions src/fairseq2/cli/commands/llama/_convert_checkpoint.py
@@ -19,7 +19,7 @@
AssetCardFieldNotFoundError,
AssetCardNotFoundError,
)
from fairseq2.cli import CliCommandHandler
from fairseq2.cli import CliArgumentError, CliCommandHandler
from fairseq2.cli.utils.rich import get_error_console
from fairseq2.context import RuntimeContext
from fairseq2.error import InternalError, ProgramError
@@ -77,9 +77,7 @@ def read_error() -> ProgramError:
raise read_error() from ex

if not input_exists:
log.error("argument input_dir: must be a directory")

return 2
raise CliArgumentError("input_dir", "must be a directory")

# Determine input checkpoint files.
input_file = input_dir.joinpath("model.pt")
@@ -108,9 +106,9 @@ def read_error() -> ProgramError:
input_files.append(input_file)

if not input_files:
log.error("argument input_dir: must contain a model checkpoint file (i.e. model.pt)") # fmt: skip

return 2
raise CliArgumentError(
"input_dir", "must contain a model checkpoint file (i.e. model.pt)"
)

output_dir: Path = args.output_dir

@@ -146,9 +144,9 @@ def write_error() -> ProgramError:
try:
card = context.asset_store.retrieve_card(args.model)
except AssetCardNotFoundError:
log.error(f"argument model: '{args.model}' is not a known LLaMA model. Use `fairseq2 assets list` to see the available models.") # fmt: skip

return 2
raise CliArgumentError(
"model", f"'{args.model}' is not a known LLaMA model. Use `fairseq2 assets list` to see the available models." # fmt: skip
) from None
except AssetCardError as ex:
raise ProgramError(
f"The '{args.model}' asset card cannot be read. See the nested exception for details."
@@ -157,18 +155,18 @@ def write_error() -> ProgramError:
try:
family = card.field("model_family").as_(str)
except AssetCardFieldNotFoundError:
log.error(f"argument model: '{args.model}' is not a known LLaMA model. Use `fairseq2 assets list` to see the available models.") # fmt: skip

return 2
raise CliArgumentError(
"model", f"'{args.model}' is not a known LLaMA model. Use `fairseq2 assets list` to see the available models." # fmt: skip
) from None
except AssetCardError as ex:
raise ProgramError(
f"The '{args.model}' asset card cannot be read. See the nested exception for details."
) from ex

if family != LLAMA_MODEL_FAMILY:
log.error(f"argument model: '{args.model}' is not a model of LLaMA family.") # fmt: skip

return 2
raise CliArgumentError(
"model", f"'{args.model}' is not a model of LLaMA family."
)

model_handlers = context.get_registry(ModelHandler)

20 changes: 10 additions & 10 deletions src/fairseq2/cli/commands/llama/_write_hf_config.py
@@ -18,7 +18,7 @@
AssetCardFieldNotFoundError,
AssetCardNotFoundError,
)
from fairseq2.cli import CliCommandHandler
from fairseq2.cli import CliArgumentError, CliCommandHandler
from fairseq2.context import RuntimeContext
from fairseq2.error import InternalError, ProgramError
from fairseq2.logging import log
@@ -51,9 +51,9 @@ def run(
try:
card = context.asset_store.retrieve_card(args.model)
except AssetCardNotFoundError:
log.error(f"argument model: '{args.model}' is not a known LLaMA model. Use `fairseq2 assets list` to see the available models.") # fmt: skip

return 2
raise CliArgumentError(
"model", f"'{args.model}' is not a known LLaMA model. Use `fairseq2 assets list` to see the available models." # fmt: skip
) from None
except AssetCardError as ex:
raise ProgramError(
f"The '{args.model}' asset card cannot be read. See the nested exception for details."
@@ -62,18 +62,18 @@ def run(
try:
family = card.field("model_family").as_(str)
except AssetCardFieldNotFoundError:
log.error(f"argument model: '{args.model}' is not a known LLaMA model. Use `fairseq2 assets list` to see the available models.") # fmt: skip

return 2
raise CliArgumentError(
"name", f"'{args.model}' is not a known LLaMA model. Use `fairseq2 assets list` to see the available models." # fmt: skip
) from None
except AssetCardError as ex:
raise ProgramError(
f"The '{args.model}' asset card cannot be read. See the nested exception for details."
) from ex

if family != LLAMA_MODEL_FAMILY:
log.error(f"argument model: '{args.model}' is not a model of LLaMA family.") # fmt: skip

return 2
raise CliArgumentError(
"model", f"'{args.model}' is not a model of LLaMA family." # fmt: skip
)

model_handlers = context.get_registry(ModelHandler)

(Diffs for the remaining 43 changed files are not shown here.)
