Fix for CPU random ops seed narrowing conversion. (#1594)

This commit is contained in:
AlbertSadovnikov 2019-08-12 09:01:13 -07:00 committed by Changming Sun
parent df9b1b8ec8
commit ce3c8f98dd

View file

@ -20,11 +20,14 @@ class RandomNormal final : public OpKernel {
// read optional seed attribute and generate if not provided
float seed = 0.f;
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if (info.GetAttr<float>("seed", &seed).IsOK()) {
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
}
else {
generator_ = std::default_random_engine{
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
};
}
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
int64_t dtype;
ORT_ENFORCE(info.GetAttr<int64_t>("dtype", &dtype).IsOK());
@ -60,11 +63,14 @@ class RandomNormalLike final : public OpKernel {
// read optional seed attribute and generate if not provided
float seed = 0.f;
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if (info.GetAttr<float>("seed", &seed).IsOK()) {
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
}
else {
generator_ = std::default_random_engine{
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
};
}
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
int64_t dtype;
if (info.GetAttr<int64_t>("dtype", &dtype).IsOK()) {
@ -94,11 +100,14 @@ class RandomUniform final : public OpKernel {
// read optional seed attribute and generate if not provided
float seed = 0.f;
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if (info.GetAttr<float>("seed", &seed).IsOK()) {
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
}
else {
generator_ = std::default_random_engine{
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
};
}
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
int64_t dtype;
ORT_ENFORCE(info.GetAttr<int64_t>("dtype", &dtype).IsOK());
@ -131,11 +140,14 @@ class RandomUniformLike final : public OpKernel {
ORT_ENFORCE(info.GetAttr<float>("low", &low_).IsOK());
// read optional seed attribute and generate if not provided
float seed = 0.f;
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if (info.GetAttr<float>("seed", &seed).IsOK()) {
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
}
else {
generator_ = std::default_random_engine{
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
};
}
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
int64_t dtype;
if (info.GetAttr<int64_t>("dtype", &dtype).IsOK()) {
@ -163,11 +175,14 @@ class Multinomial final : public OpKernel {
ORT_ENFORCE(info.GetAttr<int64_t>("sample_size", &num_samples_).IsOK());
float seed = 0.f;
if (!info.GetAttr<float>("seed", &seed).IsOK()) {
seed = gsl::narrow_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if (info.GetAttr<float>("seed", &seed).IsOK()) {
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
}
else {
generator_ = std::default_random_engine{
gsl::narrow_cast<uint32_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
};
}
generator_ = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
int64_t output_dtype_tmp;
if (!info.GetAttr<int64_t>("dtype", &output_dtype_tmp).IsOK()) {