diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e54cd5d8e..2e2db5b11 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2339,7 +2339,7 @@ class ModelTesterMixin: tied_params = [names for _, names in ptrs.items() if len(names) > 1] for group in tied_params: # We remove the group from extra_missing if not all weights from group are in it - if len(group - extra_missing) > 0: + if len(set(group) - extra_missing) > 0: extra_missing = extra_missing - set(group) self.assertEqual(