mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
remove allow-untyped-defs for torch/masked/maskedtensor/creation.py (#143321)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143321 Approved by: https://github.com/laithsakka
This commit is contained in:
parent
4d90c487d8
commit
cd7de1f4fa
1 changed files with 4 additions and 3 deletions
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
from .core import MaskedTensor
|
||||
|
|
@ -15,9 +14,11 @@ __all__ = [
|
|||
# torch.as_tensor - differentiable constructor that preserves the autograd history
|
||||
|
||||
|
||||
def masked_tensor(data, mask, requires_grad=False):
|
||||
def masked_tensor(
|
||||
data: object, mask: object, requires_grad: bool = False
|
||||
) -> MaskedTensor:
|
||||
return MaskedTensor(data, mask, requires_grad)
|
||||
|
||||
|
||||
def as_masked_tensor(data, mask):
|
||||
def as_masked_tensor(data: object, mask: object) -> MaskedTensor:
|
||||
return MaskedTensor._from_values(data, mask)
|
||||
|
|
|
|||
Loading…
Reference in a new issue