mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Watermark: fix tests (#30961)
* fix tests * style * Update tests/generation/test_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
a3c7b59e31
commit
779bc360ff
1 changed files with 3 additions and 9 deletions
|
|
@ -2148,6 +2148,8 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||
watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
|
||||
_ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)
|
||||
|
||||
# We will not check watermarked text, since we check it in `logits_processors` tests
|
||||
# Checking if generated ids are as expected fails on different hardware
|
||||
args = {
|
||||
"bias": 2.0,
|
||||
"context_width": 1,
|
||||
|
|
@ -2158,19 +2160,11 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||
output = model.generate(**model_inputs, do_sample=False, max_length=15)
|
||||
output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)
|
||||
|
||||
# check that the watermarked text is generating what is should
|
||||
self.assertListEqual(
|
||||
output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]]
|
||||
)
|
||||
self.assertListEqual(
|
||||
output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]]
|
||||
)
|
||||
|
||||
# Check that the detector is detecting watermarked text
|
||||
detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
|
||||
detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
|
||||
detection_out = detector(output[:, input_len:], return_dict=True)
|
||||
|
||||
# check that the detector is detecting watermarked text
|
||||
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
|
||||
self.assertListEqual(detection_out.prediction.tolist(), [False])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue