onnxruntime/orttraining
Suffian Khan 84589c7e05
Fuse softmax(a + b) in case of simple broadcast (#4937)
* bias softmax kernel

* remove debug comments

* remove debug comment

* windows build doesn't handle unary minus on unsigned type

* int64 => int treated as error

* only support cuda

* add bias softmax fusion tests

* PR comments

* more PR comments

* use MLTypeCallDispatcher

* break function into pieces

* add loop unroll and add to list for inference as well

* use std::min and move operator==

* revert std::min (doesn't work in the CI pipeline) and fix int to size_t error

* pr comments

* fixes for windows ci

* fix for windows ci

* pr comments on consistency

* p_model_

* fix formatting and add anonymous namespace

Co-authored-by: suffian khan <sukha@OrtTrainingDev1.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
2020-09-18 14:15:55 -07:00
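
For readers unfamiliar with the fusion named in the commit title, the following is a minimal sketch of the idea in plain Python, not the CUDA kernel this commit adds. It assumes "simple broadcast" means the bias `b` is a single row shared by every row of `a`. The function names `unfused_add_softmax` and `fused_bias_softmax` are hypothetical and do not appear in the repository. The unfused path materializes the intermediate `a + b` before applying softmax. The fused path adds the bias inside the softmax loop, so no intermediate tensor is stored.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row (subtract the row max first)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def unfused_add_softmax(a, b):
    """Reference path: build a + b as a separate intermediate, then softmax.

    Assumption: b is one row that is broadcast across every row of a.
    """
    added = [[x + y for x, y in zip(row, b)] for row in a]
    return [softmax(row) for row in added]


def fused_bias_softmax(a, b):
    """Fused path: add the bias while computing softmax, row by row.

    No intermediate a + b matrix is stored; each element is summed on the fly.
    """
    out = []
    for row in a:
        m = max(x + y for x, y in zip(row, b))
        exps = [math.exp(x + y - m) for x, y in zip(row, b)]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


a = [[1.0, 2.0, 3.0], [0.5, -1.0, 4.0]]
b = [0.1, 0.2, -0.3]
fused = fused_bias_softmax(a, b)
reference = unfused_add_softmax(a, b)
```

Both paths should agree up to floating-point rounding; in a real kernel the benefit is fewer memory reads and writes, not a different result.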
orttraining                Fuse softmax(a + b) in case of simple broadcast (#4937)  2020-09-18 14:15:55 -07:00
pytorch_frontend_examples  Fix mnist example (#4926)                                2020-08-26 15:28:39 -07:00
tools                      Update convergence baseline for ci_test. (#4465)         2020-07-09 15:29:36 +08:00