mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Add post_process_depth_estimation for GLPN (#34413)
* add depth postprocessing for GLPN * remove previous temp fix for glpn tests * Style changes for GLPN's `post_process_depth_estimation` Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * additional style fix --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
6cc4a67b3d
commit
a769ed45e1
3 changed files with 59 additions and 19 deletions
|
|
@ -14,7 +14,11 @@
|
|||
# limitations under the License.
|
||||
"""Image processor class for GLPN."""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...modeling_outputs import DepthEstimatorOutput
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
|
@ -27,12 +31,17 @@ from ...image_utils import (
|
|||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
is_torch_available,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, filter_out_non_signature_kwargs, logging
|
||||
from ...utils import TensorType, filter_out_non_signature_kwargs, logging, requires_backends
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
|
@ -218,3 +227,44 @@ class GLPNImageProcessor(BaseImageProcessor):
|
|||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
def post_process_depth_estimation(
|
||||
self,
|
||||
outputs: "DepthEstimatorOutput",
|
||||
target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
|
||||
) -> List[Dict[str, TensorType]]:
|
||||
"""
|
||||
Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
|
||||
Only supports PyTorch.
|
||||
|
||||
Args:
|
||||
outputs ([`DepthEstimatorOutput`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
(height, width) of each image in the batch. If left to None, predictions will not be resized.
|
||||
|
||||
Returns:
|
||||
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
|
||||
predictions.
|
||||
"""
|
||||
requires_backends(self, "torch")
|
||||
|
||||
predicted_depth = outputs.predicted_depth
|
||||
|
||||
if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
|
||||
)
|
||||
|
||||
results = []
|
||||
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
|
||||
for depth, target_size in zip(predicted_depth, target_sizes):
|
||||
if target_size is not None:
|
||||
depth = depth[None, None, ...]
|
||||
depth = torch.nn.functional.interpolate(depth, size=target_size, mode="bicubic", align_corners=False)
|
||||
depth = depth.squeeze()
|
||||
|
||||
results.append({"predicted_depth": depth})
|
||||
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -723,20 +723,18 @@ class GLPNForDepthEstimation(GLPNPreTrainedModel):
|
|||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
... predicted_depth = outputs.predicted_depth
|
||||
|
||||
>>> # interpolate to original size
|
||||
>>> prediction = torch.nn.functional.interpolate(
|
||||
... predicted_depth.unsqueeze(1),
|
||||
... size=image.size[::-1],
|
||||
... mode="bicubic",
|
||||
... align_corners=False,
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... target_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> # visualize the prediction
|
||||
>>> output = prediction.squeeze().cpu().numpy()
|
||||
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
|
||||
>>> depth = Image.fromarray(formatted)
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = predicted_depth * 255 / predicted_depth.max()
|
||||
>>> depth = depth.detach().cpu().numpy()
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
|
|
|
|||
|
|
@ -157,14 +157,6 @@ class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||
self.model_tester = GLPNModelTester(self)
|
||||
self.config_tester = GLPNConfigTester(self, config_class=GLPNConfig)
|
||||
|
||||
@unittest.skip(reason="Failing after #32550")
|
||||
def test_pipeline_depth_estimation(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing after #32550")
|
||||
def test_pipeline_depth_estimation_fp16(self):
|
||||
pass
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue