Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
[JAX] Replace uses of jnp.array in types with jnp.ndarray. (#26703)
`jnp.array` is a function, not a type: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`.

Co-authored-by: Peter Hawkins <phawkins@google.com>
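For illustration, a minimal sketch of the distinction the message describes (the helper names below are hypothetical, not from the commit): `jnp.array` is the array-construction function, while `jnp.ndarray` (an alias of `jax.Array`) is the type that annotations should name.

import jax.numpy as jnp

# jnp.array is a function that builds arrays:
x = jnp.array([1.0, 2.0, 3.0])

# Annotating with the function "runs" (Python does not enforce annotations),
# but it is meaningless to static type checkers, which expect a type here.
def scale_wrong(v: jnp.array) -> jnp.array:
    return v * 2.0

# jnp.ndarray is the actual array type, which is what this commit
# rewrites the annotations to use.
def scale(v: jnp.ndarray) -> jnp.ndarray:
    return v * 2.0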
parent 3eceaa3637
commit fc63914399

25 changed files with 28 additions and 28 deletions
@@ -381,7 +381,7 @@ def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="t
 
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
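For context, the surrounding function in these Flax example scripts builds the schedule with optax. The following is a sketch from memory of the transformers examples, assuming the standard optax API (optax.linear_schedule, optax.join_schedules); treat it as approximate rather than the exact file contents:

from typing import Callable

import jax.numpy as jnp
import optax

def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    # Ramp linearly from 0 up to the peak learning rate over the warmup steps.
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    # Then decay linearly back to 0 over the remaining training steps.
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - num_warmup_steps
    )
    # join_schedules switches from the warmup schedule to the decay schedule
    # at the boundary step.
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])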
@@ -326,7 +326,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
 
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -389,7 +389,7 @@ def create_train_state(
 # region Create learning rate function
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -360,7 +360,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 def create_learning_rate_fn(
     num_train_steps: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
     decay_fn = optax.linear_schedule(
@@ -409,7 +409,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -288,7 +288,7 @@ def create_train_state(
 
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -340,7 +340,7 @@ def create_train_state(
 
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -249,7 +249,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -283,7 +283,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -214,7 +214,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
 
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -217,7 +217,7 @@ BART_DECODE_INPUTS_DOCSTRING = r"""
 """
 
 
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
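Several hunks below repeat this same one-line fix across models that copy shift_tokens_right from BART. For context, a sketch of what the function's body does, reconstructed from memory of modeling_flax_bart.py (the exact code may differ): it drops the last token, prepends decoder_start_token_id, and maps the -100 label-ignore sentinel back to pad_token_id.

import jax.numpy as jnp

def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """Shift input ids one token to the right."""
    shifted_input_ids = jnp.zeros_like(input_ids)
    # Move every token one position to the right, dropping the last one.
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    # Put the decoder start token in the now-empty first slot.
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
    # -100 is the conventional "ignore this label" value; restore padding there.
    return jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)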
@@ -295,7 +295,7 @@ class FlaxBertSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
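An aside on the signature being touched here (not part of the commit): in these Flax self-attention modules, key_value_states is the cross-attention hook. When it is None the module does ordinary self-attention; when an encoder output is passed in, keys and values are projected from it instead of from hidden_states. A hypothetical distillation of that branch:

from typing import Optional, Tuple

import jax.numpy as jnp

def attention_sources(
    hidden_states: jnp.ndarray, key_value_states: Optional[jnp.ndarray] = None
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    # Self-attention: queries and keys/values both come from hidden_states.
    # Cross-attention: keys/values come from the encoder's output instead.
    kv_source = hidden_states if key_value_states is None else key_value_states
    return hidden_states, kv_source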
@@ -316,7 +316,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -204,7 +204,7 @@ BLENDERBOT_DECODE_INPUTS_DOCSTRING = r"""
 
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -216,7 +216,7 @@ BLENDERBOT_SMALL_DECODE_INPUTS_DOCSTRING = r"""
 
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -263,7 +263,7 @@ class FlaxElectraSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -1228,13 +1228,13 @@ class FlaxElectraSequenceSummary(nn.Module):
         Compute a single vector summary of a sequence hidden states.
 
         Args:
-            hidden_states (`jnp.array` of shape `[batch_size, seq_len, hidden_size]`):
+            hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`):
                 The hidden states of the last layer.
-            cls_index (`jnp.array` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+            cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                 Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
 
         Returns:
-            `jnp.array`: The summary of the sequence hidden states.
+            `jnp.ndarray`: The summary of the sequence hidden states.
         """
         # NOTE: this doest "first" type summary always
         output = hidden_states[:, 0]
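As the # NOTE context line above says, this summary module currently always takes the first token regardless of summary_type. A hypothetical standalone rendering of that shape contract (not the module itself):

import jax.numpy as jnp

def first_token_summary(hidden_states: jnp.ndarray) -> jnp.ndarray:
    # [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size]
    return hidden_states[:, 0]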
@@ -56,7 +56,7 @@ remat = nn_partitioning.remat
 
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -227,7 +227,7 @@ def create_sinusoidal_positions(n_pos, dim):
 
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -27,7 +27,7 @@ _CONFIG_FOR_DOC = "T5Config"
 
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -210,7 +210,7 @@ PEGASUS_DECODE_INPUTS_DOCSTRING = r"""
 
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -256,7 +256,7 @@ class FlaxRobertaSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -258,7 +258,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -56,7 +56,7 @@ remat = nn_partitioning.remat
 
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -266,7 +266,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -251,7 +251,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,