mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
support the trocr small models (#14893)
* support the trocr small models * resolve conflict * Update docs/source/model_doc/trocr.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/model_doc/trocr.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/model_doc/trocr.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * fix unexpected indent in processing_trocr.py * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * update the docstring of processing_trocr * remove extra space Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
parent
42d57549b8
commit
b2c477fc6d
2 changed files with 14 additions and 12 deletions
|
|
@ -55,9 +55,9 @@ Tips:
|
|||
TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of
|
||||
[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
|
||||
|
||||
The [`ViTFeatureExtractor`] class is responsible for preprocessing the input image and
|
||||
[`RobertaTokenizer`] decodes the generated target tokens to the target string. The
|
||||
[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`] and [`RobertaTokenizer`]
|
||||
The [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and
|
||||
[`RobertaTokenizer`/`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The
|
||||
[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`]
|
||||
into a single instance to both extract the input features and decode the predicted token ids.
|
||||
|
||||
- Step-by-step Optical Character Recognition (OCR)
|
||||
|
|
|
|||
|
|
@ -20,22 +20,24 @@ from contextlib import contextmanager
|
|||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
|
||||
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||
from transformers.models.xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
from transformers.models.xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
|
||||
|
||||
from ..auto.feature_extraction_auto import AutoFeatureExtractor
|
||||
from transformers import AutoTokenizer, AutoFeatureExtractor
|
||||
|
||||
|
||||
class TrOCRProcessor:
|
||||
r"""
|
||||
Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor.
|
||||
|
||||
[`TrOCRProcessor`] offers all the functionalities of [`AutoFeatureExtractor`] and [`RobertaTokenizer`]. See the
|
||||
[`TrOCRProcessor`] offers all the functionalities of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`]. See the
|
||||
[`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
feature_extractor ([`AutoFeatureExtractor`]):
|
||||
An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input.
|
||||
tokenizer ([`RobertaTokenizer`]):
|
||||
An instance of [`RobertaTokenizer`]. The tokenizer is a required input.
|
||||
feature_extractor ([`ViTFeatureExtractor`/`DeiTFeatureExtractor`]):
|
||||
An instance of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`]. The feature extractor is a required input.
|
||||
tokenizer ([`RobertaTokenizer`/`XLMRobertaTokenizer`]):
|
||||
An instance of [`RobertaTokenizer`/`XLMRobertaTokenizer`]. The tokenizer is a required input.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_extractor, tokenizer):
|
||||
|
|
@ -43,9 +45,9 @@ class TrOCRProcessor:
|
|||
raise ValueError(
|
||||
f"`feature_extractor` has to be of type {FeatureExtractionMixin.__class__}, but is {type(feature_extractor)}"
|
||||
)
|
||||
if not (isinstance(tokenizer, RobertaTokenizer) or (isinstance(tokenizer, RobertaTokenizerFast))):
|
||||
if not isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, XLMRobertaTokenizer, XLMRobertaTokenizerFast)):
|
||||
raise ValueError(
|
||||
f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__}, but is {type(tokenizer)}"
|
||||
f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__} or {XLMRobertaTokenizer.__class__} or {XLMRobertaTokenizerFast.__class__}, but is {type(tokenizer)}"
|
||||
)
|
||||
|
||||
self.feature_extractor = feature_extractor
|
||||
|
|
@ -103,7 +105,7 @@ class TrOCRProcessor:
|
|||
[`PreTrainedTokenizer`]
|
||||
"""
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue