mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Added Type hints for LED TF (#19315)
* Update modeling_tf_led.py * Update modeling_tf_led.py
This commit is contained in:
parent
3a1a56a8fe
commit
ac5ea74ee8
1 changed file with 18 additions and 16 deletions
|
|
@@ -19,6 +19,7 @@ import random
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
|
|
@@ -26,6 +27,7 @@ from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions
|
|||
|
||||
# Public API
|
||||
from ...modeling_tf_utils import (
|
||||
TFModelInputType,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
TFWrappedEmbeddings,
|
||||
|
|
@@ -2390,23 +2392,23 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
|||
@replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_outputs: Optional[TFLEDEncoderBaseModelOutput] = None,
|
||||
global_attention_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
global_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[tf.Tensor] = None,
|
||||
training: bool = False,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
|
|
|
|||
Loading…
Reference in a new issue