Have dummy processors have a from_pretrained method (#12145)

This commit is contained in:
Lysandre Debut 2021-06-15 14:39:05 +02:00 committed by GitHub
parent 9b393240a2
commit d07b540a37
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 61 additions and 0 deletions

View file

@@ -6,11 +6,19 @@ class FlaxLogitsProcessor:
def __init__(self, *_args, **_kwargs):
    # Dummy stand-in: every argument is ignored; defer straight to
    # requires_backends, which reports that the "flax" backend is absent.
    requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *_args, **_kwargs):
    """Dummy `from_pretrained`; signals via requires_backends that flax is missing."""
    requires_backends(cls, ["flax"])
class FlaxLogitsProcessorList:
    """Import-guard placeholder used when the "flax" backend is not installed.

    Both construction and `from_pretrained` hand off to `requires_backends`,
    which reports the missing backend instead of doing any real work.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["flax"])
class FlaxLogitsWarper:
def __init__(self, *args, **kwargs):

View file

@@ -127,31 +127,55 @@ class ForcedBOSTokenLogitsProcessor:
def __init__(self, *_args, **_kwargs):
    # Dummy stand-in: arguments are discarded; requires_backends reports
    # that the "torch" backend is unavailable.
    requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *_args, **_kwargs):
    """Dummy `from_pretrained`; signals via requires_backends that torch is missing."""
    requires_backends(cls, ["torch"])
class ForcedEOSTokenLogitsProcessor:
    """Import-guard placeholder used when the "torch" backend is not installed.

    Construction and `from_pretrained` both delegate to `requires_backends`,
    which reports the missing backend instead of performing any real work.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
class HammingDiversityLogitsProcessor:
    """Import-guard placeholder used when the "torch" backend is not installed.

    Any attempt to build it — directly or through `from_pretrained` — is
    routed to `requires_backends`, which reports the missing backend.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
class InfNanRemoveLogitsProcessor:
    """Import-guard placeholder used when the "torch" backend is not installed.

    Both the constructor and `from_pretrained` forward to `requires_backends`
    so the user is told which backend to install.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
class LogitsProcessor:
    """Import-guard placeholder used when the "torch" backend is not installed.

    All entry points (`__init__`, `from_pretrained`) delegate to
    `requires_backends`, which reports the missing backend.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
class LogitsProcessorList:
    """Import-guard placeholder used when the "torch" backend is not installed.

    Creating an instance or calling `from_pretrained` defers to
    `requires_backends`, which reports the missing backend.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
class LogitsWarper:
def __init__(self, *args, **kwargs):
@@ -162,26 +186,46 @@ class MinLengthLogitsProcessor:
def __init__(self, *_args, **_kwargs):
    # Dummy stand-in: ignores its arguments and lets requires_backends
    # report the absent "torch" backend.
    requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *_args, **_kwargs):
    """Dummy `from_pretrained`; signals via requires_backends that torch is missing."""
    requires_backends(cls, ["torch"])
class NoBadWordsLogitsProcessor:
    """Import-guard placeholder used when the "torch" backend is not installed.

    Both `__init__` and `from_pretrained` hand off to `requires_backends`,
    which reports the missing backend.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
class NoRepeatNGramLogitsProcessor:
    """Import-guard placeholder used when the "torch" backend is not installed.

    Every entry point defers to `requires_backends`, which reports the
    missing backend instead of doing real work.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
class PrefixConstrainedLogitsProcessor:
    """Import-guard placeholder used when the "torch" backend is not installed.

    Construction and `from_pretrained` are routed to `requires_backends`,
    which reports the missing backend.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
class RepetitionPenaltyLogitsProcessor:
    """Import-guard placeholder used when the "torch" backend is not installed.

    Both the constructor and `from_pretrained` forward to
    `requires_backends`, which reports the missing backend.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
class TemperatureLogitsWarper:
def __init__(self, *args, **kwargs):

View file

@@ -5,3 +5,7 @@ from ..file_utils import requires_backends
class Speech2TextProcessor:
    """Import-guard placeholder used when "sentencepiece" and "speech" are not installed.

    Both `__init__` and `from_pretrained` defer to `requires_backends`,
    which reports the missing backends.
    """

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["sentencepiece", "speech"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["sentencepiece", "speech"])

View file

@@ -16,6 +16,10 @@ class CLIPProcessor:
def __init__(self, *_args, **_kwargs):
    # Dummy stand-in: arguments are ignored; requires_backends reports
    # that the "vision" backend is unavailable.
    requires_backends(self, ["vision"])
@classmethod
def from_pretrained(cls, *_args, **_kwargs):
    """Dummy `from_pretrained`; signals via requires_backends that vision deps are missing."""
    requires_backends(cls, ["vision"])
class DeiTFeatureExtractor:
def __init__(self, *args, **kwargs):

View file

@@ -115,6 +115,7 @@ def create_dummy_object(name, backend_name):
"ForTokenClassification",
"Model",
"Tokenizer",
"Processor",
]
if name.isupper():
return DUMMY_CONSTANT.format(name)