Add Flax image captioning example (#14864)
* add image captioning example
* update README
* fix style & quality
* simplify
* apply review suggestions
* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
* Apply review suggestions
* add comments about using np instead jax array
* remove unused lines
* add model creation script
* only support from_pretrained
* fix style
* fix
* not use cache_dir when creating model
* fix tokenizer creation
* update README
* fix quality
* apply suggestion
* simplify some blocks
* Update examples/flax/image-captioning/README.md
* Update examples/flax/image-captioning/run_image_captioning_flax.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>
* apply suggestion

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
parent 2e9af29494
commit 9f89fa02ed

3 changed files with 1388 additions and 0 deletions

examples/flax/image-captioning/README.md (new file, 68 lines)

# Image Captioning (vision-encoder-text-decoder model) training example

The following example showcases how to finetune a vision-encoder-text-decoder model for image captioning
using the JAX/Flax backend, leveraging 🤗 Transformers library's [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel).

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way, which enables simple and efficient model parallelism.
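
Below is a minimal sketch of what this functional style looks like in plain JAX. It is only an illustration (not part of the example scripts): the training step receives the current parameters and returns a fresh, updated copy instead of mutating them.

```python
# Illustrative only: a purely functional JAX training step on a toy linear model.
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Parameters are read, never mutated.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # trace the pure function and compile it for GPU/TPU
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Return a *new* pytree of parameters; the old one stays untouched.
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 4)), jnp.ones((8, 1))
params = train_step(params, x, y)  # rebind the name to the updated copy
```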

`run_image_captioning_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets
library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.

For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files; you will also find examples of these below.
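
For instance (illustrative only; the column names simply have to match the `--image_column` / `--caption_column` flags used further below, and the file name here is just an example), a small `jsonlines` file can be produced like this:

```python
# Illustrative only: build a tiny jsonlines file whose columns match the
# --image_column image_path / --caption_column caption flags used below.
import json

examples = [
    {"image_path": "data/train2017/000000000009.jpg", "caption": "A plate of food on a table."},
    {"image_path": "data/train2017/000000000025.jpg", "caption": "A giraffe standing next to a tree."},
]

with open("train.json", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
```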
### Download COCO dataset (2017)

This example uses the COCO dataset (2017) through a custom dataset script, which requires users to manually download the
COCO dataset before training.

```bash
mkdir data
cd data
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/test2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
wget http://images.cocodataset.org/annotations/image_info_test2017.zip
cd ..
```

### Create a model from a vision encoder model and a text decoder model

Next, we create a [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel) instance from a pre-trained vision encoder ([ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.FlaxViTModel)) and a pre-trained text decoder ([GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.FlaxGPT2Model)):

```bash
python3 create_model_from_encoder_decoder_models.py \
    --output_dir model \
    --encoder_model_name_or_path google/vit-base-patch16-224-in21k \
    --decoder_model_name_or_path gpt2
```

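The resulting `model` directory is a regular checkpoint, so a quick sanity check (illustrative, not part of the example) is to load it back before training:

```python
# Illustrative sanity check: reload the checkpoint created in ./model.
from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

model = FlaxVisionEncoderDecoderModel.from_pretrained("model")
feature_extractor = AutoFeatureExtractor.from_pretrained("model")
tokenizer = AutoTokenizer.from_pretrained("model")

# The GPT2 decoder reuses its bos/eos tokens as decoder_start/pad tokens
# (see create_model_from_encoder_decoder_models.py below).
print(model.config.decoder_start_token_id, model.config.pad_token_id, model.config.eos_token_id)
```
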
### Train the model

Finally, we can run the example script to train the model:

```bash
python3 run_image_captioning_flax.py \
    --output_dir ./image-captioning-training-results \
    --model_name_or_path model \
    --dataset_name ydshieh/coco_dataset_script \
    --dataset_config_name=2017 \
    --data_dir $PWD/data \
    --image_column image_path \
    --caption_column caption \
    --do_train --do_eval --predict_with_generate \
    --num_train_epochs 1 \
    --eval_steps 500 \
    --learning_rate 3e-5 --warmup_steps 0 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --overwrite_output_dir \
    --max_target_length 32 \
    --num_beams 8 \
    --preprocessing_num_workers 16 \
    --logging_steps 10 \
    --block_size 16384 \
    --push_to_hub
```

This should finish in about 1h30 on a Cloud TPU, with a validation loss and ROUGE2 score of 2.0153 and 14.64 respectively
after 1 epoch. Training statistics can be accessed on [Models](https://huggingface.co/ydshieh/image-captioning-training-results/tensorboard).
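
Once training has finished, captions can be generated roughly as follows. This is a sketch only, not part of the example; the image path is a placeholder.

```python
# Rough inference sketch (illustrative): caption a single image with the model
# saved in the --output_dir used above.
from PIL import Image

from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

model_dir = "./image-captioning-training-results"  # the --output_dir from the command above
model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

image = Image.open("data/val2017/000000039769.jpg")  # placeholder image path
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

output_ids = model.generate(pixel_values, max_length=32, num_beams=8).sequences
caption = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
print(caption)
```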

examples/flax/image-captioning/create_model_from_encoder_decoder_models.py (new file, 118 lines)

#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create a VisionEncoderDecoderModel instance from pretrained encoder/decoder models.

The cross-attention will be randomly initialized.
"""

from dataclasses import dataclass, field
from typing import Optional

from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoTokenizer,
    FlaxVisionEncoderDecoderModel,
    HfArgumentParser,
)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model will be written."},
    )
    encoder_model_name_or_path: str = field(
        metadata={
            "help": "The encoder model checkpoint for weights initialization. "
            "Don't set if you want to train an encoder model from scratch."
        },
    )
    decoder_model_name_or_path: str = field(
        metadata={
            "help": "The decoder model checkpoint for weights initialization. "
            "Don't set if you want to train a decoder model from scratch."
        },
    )
    encoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
    )
    decoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
    )


def main():
    parser = HfArgumentParser((ModelArguments,))
    (model_args,) = parser.parse_args_into_dataclasses()

    # Load pretrained model and tokenizer

    # Use explicitly specified encoder config
    if model_args.encoder_config_name:
        encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name)
    # Use pretrained encoder model's config
    else:
        encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path)

    # Use explicitly specified decoder config
    if model_args.decoder_config_name:
        decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name)
    # Use pretrained decoder model's config
    else:
        decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path)

    # necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True

    model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
        decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
        encoder_config=encoder_config,
        decoder_config=decoder_config,
    )

    # GPT2 only has bos/eos tokens but not decoder_start/pad tokens
    decoder_start_token_id = decoder_config.decoder_start_token_id
    pad_token_id = decoder_config.pad_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = decoder_config.bos_token_id
    if pad_token_id is None:
        pad_token_id = decoder_config.eos_token_id

    # This is necessary to make Flax's generate() work
    model.config.eos_token_id = decoder_config.eos_token_id
    model.config.decoder_start_token_id = decoder_start_token_id
    model.config.pad_token_id = pad_token_id

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.encoder_model_name_or_path)

    tokenizer = AutoTokenizer.from_pretrained(model_args.decoder_model_name_or_path)
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)

    model.save_pretrained(model_args.output_dir)
    feature_extractor.save_pretrained(model_args.output_dir)
    tokenizer.save_pretrained(model_args.output_dir)


if __name__ == "__main__":
    main()

examples/flax/image-captioning/run_image_captioning_flax.py (new file, 1202 lines)
File diff suppressed because it is too large.