Update test_modeling_common.py

This commit is contained in:
Cyril Vallez 2025-02-07 01:54:55 +01:00
parent 42f02f6ee1
commit a17393187f
No known key found for this signature in database

View file

@ -2339,7 +2339,7 @@ class ModelTesterMixin:
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
for group in tied_params:
# We remove the group from extra_missing if not all weights from group are in it
if len(group - extra_missing) > 0:
if len(set(group) - extra_missing) > 0:
extra_missing = extra_missing - set(group)
self.assertEqual(