mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
parent
80dbbd103c
commit
aaa969e97d
3 changed files with 0 additions and 140 deletions
|
|
@ -1,137 +0,0 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
|
||||
|
||||
|
||||
def convert_command_factory(args: Namespace):
|
||||
"""
|
||||
Factory function used to convert a model PyTorch checkpoint in a TensorFlow 2 checkpoint.
|
||||
|
||||
Returns: ServeCommand
|
||||
"""
|
||||
return PTtoTFCommand(
|
||||
args.model_name,
|
||||
args.local_dir,
|
||||
args.max_error,
|
||||
args.new_weights,
|
||||
args.no_pr,
|
||||
args.push,
|
||||
args.extra_commit_description,
|
||||
args.override_model_class,
|
||||
)
|
||||
|
||||
|
||||
class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
"""
|
||||
Register this command to argparse so it's available for the transformer-cli
|
||||
|
||||
Args:
|
||||
parser: Root parser to register command-specific arguments
|
||||
"""
|
||||
train_parser = parser.add_parser(
|
||||
"pt-to-tf",
|
||||
help=(
|
||||
"CLI tool to run convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
|
||||
" Can also be used to validate existing weights without opening PRs, with --no-pr."
|
||||
),
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The model name, including owner/organization, as seen on the hub.",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--max-error",
|
||||
type=float,
|
||||
default=MAX_ERROR,
|
||||
help=(
|
||||
f"Maximum error tolerance. Defaults to {MAX_ERROR}. This flag should be avoided, use at your own risk."
|
||||
),
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--new-weights",
|
||||
action="store_true",
|
||||
help="Optional flag to create new TensorFlow weights, even if they already exist.",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--push",
|
||||
action="store_true",
|
||||
help="Optional flag to push the weights directly to `main` (requires permissions)",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--extra-commit-description",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--override-model-class",
|
||||
type=str,
|
||||
default=None,
|
||||
help="If you think you know better than the auto-detector, you can specify the model class here. "
|
||||
"Can be either an AutoModel class or a specific model class like BertForSequenceClassification.",
|
||||
)
|
||||
train_parser.set_defaults(func=convert_command_factory)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
local_dir: str,
|
||||
max_error: float,
|
||||
new_weights: bool,
|
||||
no_pr: bool,
|
||||
push: bool,
|
||||
extra_commit_description: str,
|
||||
override_model_class: str,
|
||||
*args,
|
||||
):
|
||||
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
|
||||
self._model_name = model_name
|
||||
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
|
||||
self._max_error = max_error
|
||||
self._new_weights = new_weights
|
||||
self._no_pr = no_pr
|
||||
self._push = push
|
||||
self._extra_commit_description = extra_commit_description
|
||||
self._override_model_class = override_model_class
|
||||
|
||||
def run(self):
|
||||
# TODO (joao): delete file in v4.47
|
||||
raise NotImplementedError(
|
||||
"\n\nConverting PyTorch weights to TensorFlow weights was removed in v4.43. "
|
||||
"Instead, we recommend that you convert PyTorch weights to Safetensors, an improved "
|
||||
"format that can be loaded by any framework, including TensorFlow. For more information, "
|
||||
"please see the Safetensors conversion guide: "
|
||||
"https://huggingface.co/docs/safetensors/en/convert-weights\n\n"
|
||||
)
|
||||
|
|
@ -20,7 +20,6 @@ from .convert import ConvertCommand
|
|||
from .download import DownloadCommand
|
||||
from .env import EnvironmentCommand
|
||||
from .lfs import LfsCommands
|
||||
from .pt_to_tf import PTtoTFCommand
|
||||
from .run import RunCommand
|
||||
from .serving import ServeCommand
|
||||
from .user import UserCommands
|
||||
|
|
@ -39,7 +38,6 @@ def main():
|
|||
UserCommands.register_subcommand(commands_parser)
|
||||
AddNewModelLikeCommand.register_subcommand(commands_parser)
|
||||
LfsCommands.register_subcommand(commands_parser)
|
||||
PTtoTFCommand.register_subcommand(commands_parser)
|
||||
|
||||
# Let's go
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -351,7 +351,6 @@ src/transformers/commands/convert.py
|
|||
src/transformers/commands/download.py
|
||||
src/transformers/commands/env.py
|
||||
src/transformers/commands/lfs.py
|
||||
src/transformers/commands/pt_to_tf.py
|
||||
src/transformers/commands/run.py
|
||||
src/transformers/commands/serving.py
|
||||
src/transformers/commands/train.py
|
||||
|
|
|
|||
Loading…
Reference in a new issue