mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-23 22:20:18 +00:00
20 lines
416 B
Python
20 lines
416 B
Python
import random
|
|
|
|
import torch as th
|
|
import numpy as np
|
|
|
|
|
|
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Seed the different random generators (Python ``random``, NumPy, PyTorch).

    :param seed: (int) the seed applied to every random number generator
    :param using_cuda: (bool) if True, also seed all CUDA devices and put
        CuDNN in deterministic mode (this may reduce GPU performance)
    """
    # Python's built-in RNG
    random.seed(seed)
    # NumPy's global RNG
    np.random.seed(seed)
    # PyTorch RNG (CPU; recent PyTorch versions also seed CUDA devices here)
    th.manual_seed(seed)

    if using_cuda:
        # Seed every visible GPU explicitly, not just the current one.
        # This call is lazy, so it is safe even before CUDA is initialized.
        th.cuda.manual_seed_all(seed)
        # Make CuDNN deterministic. `deterministic` alone is not enough.
        # With `benchmark` enabled, cuDNN picks algorithms by runtime timing,
        # which can differ between runs. Force it off in case other code
        # turned it on.
        th.backends.cudnn.deterministic = True
        th.backends.cudnn.benchmark = False
|