diff --git a/docs/source/model_doc/trocr.mdx b/docs/source/model_doc/trocr.mdx index b8cdbaeb6..494895c1a 100644 --- a/docs/source/model_doc/trocr.mdx +++ b/docs/source/model_doc/trocr.mdx @@ -55,9 +55,9 @@ Tips: TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of [`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image. -The [`ViTFeatureExtractor`] class is responsible for preprocessing the input image and -[`RobertaTokenizer`] decodes the generated target tokens to the target string. The -[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`] and [`RobertaTokenizer`] +The [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and +[`RobertaTokenizer`/`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The +[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`] into a single instance to both extract the input features and decode the predicted token ids. - Step-by-step Optical Character Recognition (OCR) diff --git a/src/transformers/models/trocr/processing_trocr.py b/src/transformers/models/trocr/processing_trocr.py index 3166cbae2..6022e54d2 100644 --- a/src/transformers/models/trocr/processing_trocr.py +++ b/src/transformers/models/trocr/processing_trocr.py @@ -20,22 +20,24 @@ from contextlib import contextmanager from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.models.roberta.tokenization_roberta import RobertaTokenizer from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast +from transformers.models.xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer +from transformers.models.xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast -from ..auto.feature_extraction_auto import AutoFeatureExtractor +from transformers import AutoTokenizer, AutoFeatureExtractor class TrOCRProcessor: r""" Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor. - [`TrOCRProcessor`] offers all the functionalities of [`AutoFeatureExtractor`] and [`RobertaTokenizer`]. See the + [`TrOCRProcessor`] offers all the functionalities of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`]. See the [`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for more information. Args: - feature_extractor ([`AutoFeatureExtractor`]): - An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input. - tokenizer ([`RobertaTokenizer`]): - An instance of [`RobertaTokenizer`]. The tokenizer is a required input. + feature_extractor ([`ViTFeatureExtractor`/`DeiTFeatureExtractor`]): + An instance of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`]. The feature extractor is a required input. + tokenizer ([`RobertaTokenizer`/`XLMRobertaTokenizer`]): + An instance of [`RobertaTokenizer`/`XLMRobertaTokenizer`]. The tokenizer is a required input. """ def __init__(self, feature_extractor, tokenizer): @@ -43,9 +45,9 @@ class TrOCRProcessor: raise ValueError( f"`feature_extractor` has to be of type {FeatureExtractionMixin.__class__}, but is {type(feature_extractor)}" ) - if not (isinstance(tokenizer, RobertaTokenizer) or (isinstance(tokenizer, RobertaTokenizerFast))): + if not isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, XLMRobertaTokenizer, XLMRobertaTokenizerFast)): raise ValueError( - f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__}, but is {type(tokenizer)}" + f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__} or {XLMRobertaTokenizer.__class__} or {XLMRobertaTokenizerFast.__class__}, but is {type(tokenizer)}" ) self.feature_extractor = feature_extractor @@ -103,7 +105,7 @@ class TrOCRProcessor: [`PreTrainedTokenizer`] """ feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) - tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)