mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Fix for CPU random ops seed narrowing conversion. (#1594)
This commit is contained in:
parent
df9b1b8ec8
commit
ce3c8f98dd
1 changed files with 35 additions and 20 deletions
|
|
@ -20,11 +20,14 @@ class RandomNormal final : public OpKernel {
|
|||
|
||||
// read optional seed attribute and generate if not provided
|
||||
float seed = 0.f;
|
||||
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
|
||||
if (info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
}
|
||||
else {
|
||||
generator_ = std::default_random_engine{
|
||||
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
|
||||
};
|
||||
}
|
||||
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
|
||||
int64_t dtype;
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("dtype", &dtype).IsOK());
|
||||
|
|
@ -60,11 +63,14 @@ class RandomNormalLike final : public OpKernel {
|
|||
|
||||
// read optional seed attribute and generate if not provided
|
||||
float seed = 0.f;
|
||||
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
|
||||
if (info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
}
|
||||
else {
|
||||
generator_ = std::default_random_engine{
|
||||
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
|
||||
};
|
||||
}
|
||||
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
|
||||
int64_t dtype;
|
||||
if (info.GetAttr<int64_t>("dtype", &dtype).IsOK()) {
|
||||
|
|
@ -94,11 +100,14 @@ class RandomUniform final : public OpKernel {
|
|||
|
||||
// read optional seed attribute and generate if not provided
|
||||
float seed = 0.f;
|
||||
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
|
||||
if (info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
}
|
||||
else {
|
||||
generator_ = std::default_random_engine{
|
||||
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
|
||||
};
|
||||
}
|
||||
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
|
||||
int64_t dtype;
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("dtype", &dtype).IsOK());
|
||||
|
|
@ -131,11 +140,14 @@ class RandomUniformLike final : public OpKernel {
|
|||
ORT_ENFORCE(info.GetAttr<float>("low", &low_).IsOK());
|
||||
// read optional seed attribute and generate if not provided
|
||||
float seed = 0.f;
|
||||
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
|
||||
if (info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
}
|
||||
else {
|
||||
generator_ = std::default_random_engine{
|
||||
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
|
||||
};
|
||||
}
|
||||
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
|
||||
int64_t dtype;
|
||||
if (info.GetAttr<int64_t>("dtype", &dtype).IsOK()) {
|
||||
|
|
@ -163,11 +175,14 @@ class Multinomial final : public OpKernel {
|
|||
ORT_ENFORCE(info.GetAttr<int64_t>("sample_size", &num_samples_).IsOK());
|
||||
|
||||
float seed = 0.f;
|
||||
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
|
||||
if (info.GetAttr<float>("seed", &seed).IsOK()) {
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
}
|
||||
else {
|
||||
generator_ = std::default_random_engine{
|
||||
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
|
||||
};
|
||||
}
|
||||
|
||||
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
|
||||
int64_t output_dtype_tmp;
|
||||
if (!info.GetAttr<int64_t>("dtype", &output_dtype_tmp).IsOK()) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue