diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 97bccd2cb..bcfb6bfe5 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -113,10 +113,13 @@ class TFModelTesterMixin: tf_hidden_states = tfo[0].numpy() pt_hidden_states = pto[0].numpy() - pt_hidden_states[np.isnan(tf_hidden_states)] = 0 - tf_hidden_states[np.isnan(tf_hidden_states)] = 0 - pt_hidden_states[np.isnan(pt_hidden_states)] = 0 - tf_hidden_states[np.isnan(pt_hidden_states)] = 0 + tf_nans = np.copy(np.isnan(tf_hidden_states)) + pt_nans = np.copy(np.isnan(pt_hidden_states)) + + pt_hidden_states[tf_nans] = 0 + tf_hidden_states[tf_nans] = 0 + pt_hidden_states[pt_nans] = 0 + tf_hidden_states[pt_nans] = 0 max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states)) # Debug info (remove when fixed) @@ -148,8 +151,14 @@ class TFModelTesterMixin: tfo = tf_model(inputs_dict) tfo = tfo[0].numpy() pto = pto[0].numpy() - tfo[np.isnan(tfo)] = 0 - pto[np.isnan(pto)] = 0 + tf_nans = np.copy(np.isnan(tfo)) + pt_nans = np.copy(np.isnan(pto)) + + pto[tf_nans] = 0 + tfo[tf_nans] = 0 + pto[pt_nans] = 0 + tfo[pt_nans] = 0 + max_diff = np.amax(np.abs(tfo - pto)) self.assertLessEqual(max_diff, 2e-2)