* disable materialize grads
* gradient builder bugfix
* fix ut
* fix ut
* resolve comments and bugfix
* add more assert
* disable forward compare for now
* priority-based exec order
* disable 1 failing test
* fix UT
* more comments
Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
* Add more asserts on forward outputs
* Found one more failing case
Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
* Added required_grad attribute to YieldOp
* Chagened YieldOp attribute to hold the indices of the required gradient outputs from the count, and removed the code reordering the outputs.
* Changed backward_output_grad_names to a map from backward output gradient name to the corresponding output index.
* Introduce OrtTasks to replace EventPool
* return run_id to frontend
* pass run_id to backward
* OrtTasks support multiple bg_events
* make message_queue a member of orttask
* Replace MessageQueue with std::promise
* Move status_promise into Task
* Move terminate flag into Task
* Reenable previously disabled UTs
* Add unit tests
* Replace condition variables with std::promise
* Move to CreateBackgroundTask in the main thread
* return status and output in forward_future
* use throw for terminating background thread
* cleanup tasks at destructor
* reenable test_mixed_nnmodule_ortmodules_training
* add mutex for ORTTasks functions
* add mutex for bg_threads
* delay tests before start
* add ut for multi-task common backbone
Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
* remove tests to speed up CI
* add back _into_data_parallelism tests to see how long the CI test takes
* remove unnecessary save calls
* add back data_parallelism_full_precision_bart_path
* add data_parallelism_full_precision_path
* remove data parallelism tests
Co-authored-by: Jingyan Wang <jingywa@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
* Update torchtext usage for pytorch transformer sample
* Temporarily disable tests to unblock repo (failures are being worked on already)
* Update loss numbers for ORTTrainer UTs
* Support keyword arguments for ORTModule.
* Add backward workaround to the test.
* Specify test name directly without -k.
* Handle unused inputs removed by ONNX exporter.
* Enable external CUDA allocator in ORTModule.
* Fix assert after unification of allocators.
* Update no grad memory test.
* update comments.
* fix provider options array when not sharing allocator.
* Partial updating of ROCM reduction code.
* Update reduction_all.cu
* Add reduce template parameters.
* miopen common
* Reuse CUDA's reduction_functions.cc
* Reduction ops.
* Update remaining reduction ops to use MIOpen. double datatype is not supported, so disable those typed kernels.
* Disable a couple more unsupported tests.
* Code formatting.
* Delete ROCM-specific reduction code that is identical to CUDA reduction code.
* Fix scratch buffer early free.
* Fix merge conflict.
* first attempt nightly amd ci pipeline
* try fix bad yaml file
* try again with corrected model directory
* add convergence test as well
* update reference loss for amd mi100
* include mi100 test results csv
* update the mi100 convergence test reference values
* update batch sizes for mi100 32g
* fix gpu sku for run_convergence_test.py
* undo unrelated changes to master
* pr comments
* pr comment
Co-authored-by: Jesse Benson <jesseb@microsoft.com>
* ortmodule v0.2
* use pt module for eval
* get user outputs in yield op
* pass output grads to yield output without copy
* Disable mem_pattern for ORTModule
* Avoid allocating output buffer for Yield op
* Change to WaitAndReset to avoid overriding signal
* remove unnecessory signal/wait at the end of bg thread
* Return Session.Run result as a std::future
* export model with torch.no_grad()
* Handle bg thread's early return in Forward call
* Removed duplicated Yield kernel
* Silence "CUDA kernel missing log"
* Add missing transforms, clear iobinding (#6532)
* revert ortmodule.py to a working state first
* Apply ortmodule.py change from dev branch
* Rename to YieldOp
Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: ashbhandare <ash.bhandare@gmail.com>
Co-authored-by: Sherlock <baihan.huang@gmail.com>
* Add warning when GetGradientForOp() silently fails.
In some cases, `GetGradientForOp()` can return without creating any nodes, which may lead to an invalid graph being created.
Remove condition from ORT_RETURN_IF[_NOT] macro output as repeating the condition doesn't add much value compared to the explicit error message, and the error message includes the file and line anyway so it's easy enough to find the condition if needed.
Update the few places where the macros were used without an explicit error message to provide an explicit error message.
Saves 12.5KB in a minimal MinSizeRel build with all DNN ops, 16KB in full release build.
* Support to allow user to specify compute stream per session
Create computation cuda stream explicitly rather than use default legacy stream or per-thread default stream.
remove some redudant cudaStreamSynchronize
fix gpt2 model test failures
don't use default stream in nccl either.
add stream schronization in OnRunEnd()
using cub::DeviceScan::InclusiveSum which can be called with stream specified.
fix topK failure due to latest rebase
fix tensorrt
support user specified stream
add user_stream support in tensorrt EP
use same stream for both tensort and CUDA EP.
fix ScatterND
specify stream for adasum and p2p kernels.
fix loop
fix CApiTest.custom_op_handler
fix CApiTest.varied_input_custom_op_handler
change for cudaMemcpyFromSymbol
improve provider options for user specified compute stream
* add changes for ROCM EP
* fix GatherGrad UT for ROCM EP
* clean code and fix NonMaxSuppression
* use default stream for ROCM now
* fix CApiTest.custom_op_handler:OrtFormatCustomOpTests.ConvertOnnxModelToOrt
* fix tensorrt ut: CApiTest.io_binding_cuda
Co-authored-by: Weixing Zhang <wezhan@microsoft.com>