diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index f5c882bf128afb..906ee4ea620b2a 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -75,39 +75,104 @@ values. Here, for instance, it has two keys that are `sequences` and `scores`. We document here all output types. -### GreedySearchOutput - -[[autodoc]] generation.GreedySearchDecoderOnlyOutput +### PyTorch [[autodoc]] generation.GreedySearchEncoderDecoderOutput -[[autodoc]] generation.FlaxGreedySearchOutput +[[autodoc]] generation.GreedySearchDecoderOnlyOutput -### SampleOutput +[[autodoc]] generation.SampleEncoderDecoderOutput [[autodoc]] generation.SampleDecoderOnlyOutput -[[autodoc]] generation.SampleEncoderDecoderOutput +[[autodoc]] generation.BeamSearchEncoderDecoderOutput -[[autodoc]] generation.FlaxSampleOutput +[[autodoc]] generation.BeamSearchDecoderOnlyOutput -### BeamSearchOutput +[[autodoc]] generation.BeamSampleEncoderDecoderOutput -[[autodoc]] generation.BeamSearchDecoderOnlyOutput +[[autodoc]] generation.BeamSampleDecoderOnlyOutput -[[autodoc]] generation.BeamSearchEncoderDecoderOutput +[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput -### BeamSampleOutput +[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput -[[autodoc]] generation.BeamSampleDecoderOnlyOutput +### TensorFlow -[[autodoc]] generation.BeamSampleEncoderDecoderOutput +[[autodoc]] generation.TFGreedySearchEncoderDecoderOutput + +[[autodoc]] generation.TFGreedySearchDecoderOnlyOutput + +[[autodoc]] generation.TFSampleEncoderDecoderOutput + +[[autodoc]] generation.TFSampleDecoderOnlyOutput + +[[autodoc]] generation.TFBeamSearchEncoderDecoderOutput + +[[autodoc]] generation.TFBeamSearchDecoderOnlyOutput + +[[autodoc]] generation.TFBeamSampleEncoderDecoderOutput + +[[autodoc]] generation.TFBeamSampleDecoderOnlyOutput + +[[autodoc]] generation.TFContrastiveSearchEncoderDecoderOutput + +[[autodoc]] generation.TFContrastiveSearchDecoderOnlyOutput + +### FLAX + +[[autodoc]] generation.FlaxSampleOutput + +[[autodoc]] generation.FlaxGreedySearchOutput + +[[autodoc]] generation.FlaxBeamSearchOutput ## LogitsProcessor A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for generation. +### PyTorch + +[[autodoc]] AlternatingCodebooksLogitsProcessor + - __call__ + +[[autodoc]] ClassifierFreeGuidanceLogitsProcessor + - __call__ + +[[autodoc]] EncoderNoRepeatNGramLogitsProcessor + - __call__ + +[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor + - __call__ + +[[autodoc]] EpsilonLogitsWarper + - __call__ + +[[autodoc]] EtaLogitsWarper + - __call__ + +[[autodoc]] ExponentialDecayLengthPenalty + - __call__ + +[[autodoc]] ForcedBOSTokenLogitsProcessor + - __call__ + +[[autodoc]] ForcedEOSTokenLogitsProcessor + - __call__ + +[[autodoc]] ForceTokensLogitsProcessor + - __call__ + +[[autodoc]] HammingDiversityLogitsProcessor + - __call__ + +[[autodoc]] InfNanRemoveLogitsProcessor + - __call__ + +[[autodoc]] LogitNormalization + - __call__ + [[autodoc]] LogitsProcessor - __call__ @@ -123,61 +188,63 @@ generation. [[autodoc]] MinNewTokensLengthLogitsProcessor - __call__ -[[autodoc]] TemperatureLogitsWarper +[[autodoc]] NoBadWordsLogitsProcessor - __call__ -[[autodoc]] RepetitionPenaltyLogitsProcessor +[[autodoc]] NoRepeatNGramLogitsProcessor - __call__ -[[autodoc]] TopPLogitsWarper +[[autodoc]] PrefixConstrainedLogitsProcessor - __call__ -[[autodoc]] TopKLogitsWarper +[[autodoc]] RepetitionPenaltyLogitsProcessor - __call__ -[[autodoc]] TypicalLogitsWarper +[[autodoc]] SequenceBiasLogitsProcessor - __call__ -[[autodoc]] NoRepeatNGramLogitsProcessor +[[autodoc]] SuppressTokensAtBeginLogitsProcessor - __call__ -[[autodoc]] SequenceBiasLogitsProcessor +[[autodoc]] SuppressTokensLogitsProcessor - __call__ -[[autodoc]] NoBadWordsLogitsProcessor +[[autodoc]] TemperatureLogitsWarper - __call__ -[[autodoc]] PrefixConstrainedLogitsProcessor +[[autodoc]] TopKLogitsWarper - __call__ -[[autodoc]] HammingDiversityLogitsProcessor +[[autodoc]] TopPLogitsWarper - __call__ -[[autodoc]] ForcedBOSTokenLogitsProcessor +[[autodoc]] TypicalLogitsWarper - __call__ -[[autodoc]] ForcedEOSTokenLogitsProcessor +[[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor - __call__ -[[autodoc]] InfNanRemoveLogitsProcessor +[[autodoc]] WhisperTimeStampLogitsProcessor - __call__ -[[autodoc]] TFLogitsProcessor +### TensorFlow + +[[autodoc]] TFForcedBOSTokenLogitsProcessor - __call__ -[[autodoc]] TFLogitsProcessorList +[[autodoc]] TFForcedEOSTokenLogitsProcessor - __call__ -[[autodoc]] TFLogitsWarper +[[autodoc]] TFForceTokensLogitsProcessor - __call__ -[[autodoc]] TFTemperatureLogitsWarper +[[autodoc]] TFLogitsProcessor - __call__ -[[autodoc]] TFTopPLogitsWarper +[[autodoc]] TFLogitsProcessorList - __call__ -[[autodoc]] TFTopKLogitsWarper +[[autodoc]] TFLogitsWarper - __call__ [[autodoc]] TFMinLengthLogitsProcessor @@ -192,10 +259,30 @@ generation. [[autodoc]] TFRepetitionPenaltyLogitsProcessor - __call__ -[[autodoc]] TFForcedBOSTokenLogitsProcessor +[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor - __call__ -[[autodoc]] TFForcedEOSTokenLogitsProcessor +[[autodoc]] TFSuppressTokensLogitsProcessor + - __call__ + +[[autodoc]] TFTemperatureLogitsWarper + - __call__ + +[[autodoc]] TFTopKLogitsWarper + - __call__ + +[[autodoc]] TFTopPLogitsWarper + - __call__ + +### FLAX + +[[autodoc]] FlaxForcedBOSTokenLogitsProcessor + - __call__ + +[[autodoc]] FlaxForcedEOSTokenLogitsProcessor + - __call__ + +[[autodoc]] FlaxForceTokensLogitsProcessor - __call__ [[autodoc]] FlaxLogitsProcessor @@ -207,27 +294,30 @@ generation. [[autodoc]] FlaxLogitsWarper - __call__ -[[autodoc]] FlaxTemperatureLogitsWarper +[[autodoc]] FlaxMinLengthLogitsProcessor - __call__ -[[autodoc]] FlaxTopPLogitsWarper +[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor - __call__ -[[autodoc]] FlaxTopKLogitsWarper +[[autodoc]] FlaxSuppressTokensLogitsProcessor - __call__ -[[autodoc]] FlaxForcedBOSTokenLogitsProcessor +[[autodoc]] FlaxTemperatureLogitsWarper - __call__ -[[autodoc]] FlaxForcedEOSTokenLogitsProcessor +[[autodoc]] FlaxTopKLogitsWarper - __call__ -[[autodoc]] FlaxMinLengthLogitsProcessor +[[autodoc]] FlaxTopPLogitsWarper + - __call__ + +[[autodoc]] FlaxWhisperTimeStampLogitsProcessor - __call__ ## StoppingCriteria -A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). +A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). Please note that this is exclusivelly available to our PyTorch implementations. [[autodoc]] StoppingCriteria - __call__ @@ -243,7 +333,7 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than ## Constraints -A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. +A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusivelly available to our PyTorch implementations. [[autodoc]] Constraint diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9b95aadffccc6f..2a05f767650869 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1005,17 +1005,26 @@ _import_structure["deepspeed"] = [] _import_structure["generation"].extend( [ + "AlternatingCodebooksLogitsProcessor", "BeamScorer", "BeamSearchScorer", + "ClassifierFreeGuidanceLogitsProcessor", "ConstrainedBeamSearchScorer", "Constraint", "ConstraintListState", "DisjunctiveConstraint", + "EncoderNoRepeatNGramLogitsProcessor", + "EncoderRepetitionPenaltyLogitsProcessor", + "EpsilonLogitsWarper", + "EtaLogitsWarper", + "ExponentialDecayLengthPenalty", "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", + "ForceTokensLogitsProcessor", "GenerationMixin", "HammingDiversityLogitsProcessor", "InfNanRemoveLogitsProcessor", + "LogitNormalization", "LogitsProcessor", "LogitsProcessorList", "LogitsWarper", @@ -1031,10 +1040,14 @@ "SequenceBiasLogitsProcessor", "StoppingCriteria", "StoppingCriteriaList", + "SuppressTokensAtBeginLogitsProcessor", + "SuppressTokensLogitsProcessor", "TemperatureLogitsWarper", "TopKLogitsWarper", "TopPLogitsWarper", "TypicalLogitsWarper", + "UnbatchedClassifierFreeGuidanceLogitsProcessor", + "WhisperTimeStampLogitsProcessor", "top_k_top_p_filtering", ] ) @@ -3115,6 +3128,7 @@ [ "TFForcedBOSTokenLogitsProcessor", "TFForcedEOSTokenLogitsProcessor", + "TFForceTokensLogitsProcessor", "TFGenerationMixin", "TFLogitsProcessor", "TFLogitsProcessorList", @@ -3123,6 +3137,8 @@ "TFNoBadWordsLogitsProcessor", "TFNoRepeatNGramLogitsProcessor", "TFRepetitionPenaltyLogitsProcessor", + "TFSuppressTokensAtBeginLogitsProcessor", + "TFSuppressTokensLogitsProcessor", "TFTemperatureLogitsWarper", "TFTopKLogitsWarper", "TFTopPLogitsWarper", @@ -3836,14 +3852,18 @@ [ "FlaxForcedBOSTokenLogitsProcessor", "FlaxForcedEOSTokenLogitsProcessor", + "FlaxForceTokensLogitsProcessor", "FlaxGenerationMixin", "FlaxLogitsProcessor", "FlaxLogitsProcessorList", "FlaxLogitsWarper", "FlaxMinLengthLogitsProcessor", "FlaxTemperatureLogitsWarper", + "FlaxSuppressTokensAtBeginLogitsProcessor", + "FlaxSuppressTokensLogitsProcessor", "FlaxTopKLogitsWarper", "FlaxTopPLogitsWarper", + "FlaxWhisperTimeStampLogitsProcessor", ] ) _import_structure["generation_flax_utils"] = [] @@ -4982,17 +5002,26 @@ TextDatasetForNextSentencePrediction, ) from .generation import ( + AlternatingCodebooksLogitsProcessor, BeamScorer, BeamSearchScorer, + ClassifierFreeGuidanceLogitsProcessor, ConstrainedBeamSearchScorer, Constraint, ConstraintListState, DisjunctiveConstraint, + EncoderNoRepeatNGramLogitsProcessor, + EncoderRepetitionPenaltyLogitsProcessor, + EpsilonLogitsWarper, + EtaLogitsWarper, + ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, + ForceTokensLogitsProcessor, GenerationMixin, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, + LogitNormalization, LogitsProcessor, LogitsProcessorList, LogitsWarper, @@ -5008,10 +5037,14 @@ SequenceBiasLogitsProcessor, StoppingCriteria, StoppingCriteriaList, + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, + UnbatchedClassifierFreeGuidanceLogitsProcessor, + WhisperTimeStampLogitsProcessor, top_k_top_p_filtering, ) from .modeling_utils import PreTrainedModel @@ -6711,6 +6744,7 @@ from .generation import ( TFForcedBOSTokenLogitsProcessor, TFForcedEOSTokenLogitsProcessor, + TFForceTokensLogitsProcessor, TFGenerationMixin, TFLogitsProcessor, TFLogitsProcessorList, @@ -6719,6 +6753,8 @@ TFNoBadWordsLogitsProcessor, TFNoRepeatNGramLogitsProcessor, TFRepetitionPenaltyLogitsProcessor, + TFSuppressTokensAtBeginLogitsProcessor, + TFSuppressTokensLogitsProcessor, TFTemperatureLogitsWarper, TFTopKLogitsWarper, TFTopPLogitsWarper, @@ -7284,14 +7320,18 @@ from .generation import ( FlaxForcedBOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor, + FlaxForceTokensLogitsProcessor, FlaxGenerationMixin, FlaxLogitsProcessor, FlaxLogitsProcessorList, FlaxLogitsWarper, FlaxMinLengthLogitsProcessor, + FlaxSuppressTokensAtBeginLogitsProcessor, + FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, + FlaxWhisperTimeStampLogitsProcessor, ) from .modeling_flax_utils import FlaxPreTrainedModel diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index f0da9f514e7af0..a46cb4fa910ada 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -41,12 +41,19 @@ "ConstrainedBeamSearchScorer", ] _import_structure["logits_process"] = [ + "AlternatingCodebooksLogitsProcessor", + "ClassifierFreeGuidanceLogitsProcessor", + "EncoderNoRepeatNGramLogitsProcessor", + "EncoderRepetitionPenaltyLogitsProcessor", "EpsilonLogitsWarper", "EtaLogitsWarper", + "ExponentialDecayLengthPenalty", "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", + "ForceTokensLogitsProcessor", "HammingDiversityLogitsProcessor", "InfNanRemoveLogitsProcessor", + "LogitNormalization", "LogitsProcessor", "LogitsProcessorList", "LogitsWarper", @@ -57,15 +64,14 @@ "PrefixConstrainedLogitsProcessor", "RepetitionPenaltyLogitsProcessor", "SequenceBiasLogitsProcessor", - "EncoderRepetitionPenaltyLogitsProcessor", + "SuppressTokensLogitsProcessor", + "SuppressTokensAtBeginLogitsProcessor", "TemperatureLogitsWarper", "TopKLogitsWarper", "TopPLogitsWarper", "TypicalLogitsWarper", - "EncoderNoRepeatNGramLogitsProcessor", - "ExponentialDecayLengthPenalty", - "LogitNormalization", "UnbatchedClassifierFreeGuidanceLogitsProcessor", + "WhisperTimeStampLogitsProcessor", ] _import_structure["stopping_criteria"] = [ "MaxNewTokensCriteria", @@ -99,6 +105,7 @@ _import_structure["tf_logits_process"] = [ "TFForcedBOSTokenLogitsProcessor", "TFForcedEOSTokenLogitsProcessor", + "TFForceTokensLogitsProcessor", "TFLogitsProcessor", "TFLogitsProcessorList", "TFLogitsWarper", @@ -106,12 +113,11 @@ "TFNoBadWordsLogitsProcessor", "TFNoRepeatNGramLogitsProcessor", "TFRepetitionPenaltyLogitsProcessor", + "TFSuppressTokensAtBeginLogitsProcessor", + "TFSuppressTokensLogitsProcessor", "TFTemperatureLogitsWarper", "TFTopKLogitsWarper", "TFTopPLogitsWarper", - "TFForceTokensLogitsProcessor", - "TFSuppressTokensAtBeginLogitsProcessor", - "TFSuppressTokensLogitsProcessor", ] _import_structure["tf_utils"] = [ "TFGenerationMixin", @@ -137,13 +143,17 @@ _import_structure["flax_logits_process"] = [ "FlaxForcedBOSTokenLogitsProcessor", "FlaxForcedEOSTokenLogitsProcessor", + "FlaxForceTokensLogitsProcessor", "FlaxLogitsProcessor", "FlaxLogitsProcessorList", "FlaxLogitsWarper", "FlaxMinLengthLogitsProcessor", + "FlaxSuppressTokensAtBeginLogitsProcessor", + "FlaxSuppressTokensLogitsProcessor", "FlaxTemperatureLogitsWarper", "FlaxTopKLogitsWarper", "FlaxTopPLogitsWarper", + "FlaxWhisperTimeStampLogitsProcessor", ] _import_structure["flax_utils"] = [ "FlaxGenerationMixin", @@ -165,6 +175,8 @@ from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .logits_process import ( + AlternatingCodebooksLogitsProcessor, + ClassifierFreeGuidanceLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor, EncoderRepetitionPenaltyLogitsProcessor, EpsilonLogitsWarper, @@ -172,6 +184,7 @@ ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, + ForceTokensLogitsProcessor, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, LogitNormalization, @@ -185,11 +198,14 @@ PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, SequenceBiasLogitsProcessor, + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, + WhisperTimeStampLogitsProcessor, ) from .stopping_criteria import ( MaxLengthCriteria, @@ -261,13 +277,17 @@ from .flax_logits_process import ( FlaxForcedBOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor, + FlaxForceTokensLogitsProcessor, FlaxLogitsProcessor, FlaxLogitsProcessorList, FlaxLogitsWarper, FlaxMinLengthLogitsProcessor, + FlaxSuppressTokensAtBeginLogitsProcessor, + FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, + FlaxWhisperTimeStampLogitsProcessor, ) from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput else: diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 78be4ef747e96a..0f6af902ec26b8 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -16,6 +16,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxForceTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxGenerationMixin(metaclass=DummyObject): _backends = ["flax"] @@ -51,6 +58,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxSuppressTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxTemperatureLogitsWarper(metaclass=DummyObject): _backends = ["flax"] @@ -72,6 +93,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxPreTrainedModel(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5724e689f2fce2..c1cdc3955e97dc 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -79,6 +79,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class BeamScorer(metaclass=DummyObject): _backends = ["torch"] @@ -93,6 +100,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ConstrainedBeamSearchScorer(metaclass=DummyObject): _backends = ["torch"] @@ -121,6 +135,41 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class EncoderNoRepeatNGramLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class EpsilonLogitsWarper(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class EtaLogitsWarper(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ExponentialDecayLengthPenalty(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject): _backends = ["torch"] @@ -135,6 +184,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ForceTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class GenerationMixin(metaclass=DummyObject): _backends = ["torch"] @@ -156,6 +212,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class LogitNormalization(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LogitsProcessor(metaclass=DummyObject): _backends = ["torch"] @@ -261,6 +324,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SuppressTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class TemperatureLogitsWarper(metaclass=DummyObject): _backends = ["torch"] @@ -289,6 +366,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class WhisperTimeStampLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + def top_k_top_p_filtering(*args, **kwargs): requires_backends(top_k_top_p_filtering, ["torch"]) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 46cde8ffbef434..9b1aae44932668 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -30,6 +30,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFForceTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFGenerationMixin(metaclass=DummyObject): _backends = ["tf"] @@ -86,6 +93,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSuppressTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFTemperatureLogitsWarper(metaclass=DummyObject): _backends = ["tf"]