mirror of https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
This PR implements DistributedExpand for Llama 2. Representative examples of DistributedExpand:

- [Shard on a non-expanded axis] `input tensor (shape=[8, 1], spec=S[0]R, device_mesh=[0,1]) -> Expand(target_shape=[8, 2]) -> output tensor (shape=[8, 2], spec=S[0]R, device_mesh=[0,1])`
- [Sharding an expanded axis is invalid, since an expanded axis must have dim=1 and an axis with dim=1 cannot be sharded] `input tensor (shape=[1, 8], spec=S[0]R, device_mesh=[0,1]) -> Expand(target_shape=[2, 8]) -> output tensor (shape=[2, 8], spec=S[0]R, device_mesh=[0,1])`

From these examples, we observe a few important behaviors:

- The output sharding spec is always the same as the input sharding spec.
- Expansion always happens on axes with dimension 1; otherwise it would violate the broadcasting rule.
- No communication is needed, since all computation happens locally.

Consider the first example again. If you put the first half of the tensor (shape: [4, 1]) on device 0 and the second half (shape: [4, 1]) on device 1, then `Expand` each half with target shape [4, 2], the two local tensors (shape: [4, 2]) are exactly the shards described by the output sharding spec.

Algorithm:

- Compute the logical (i.e., unsharded) shapes of the input and output.
- Compute the sharded output shape from the logical output shape.
- Call Expand to broadcast the local input to the sharded output shape.

How to review?

- Start with [changes in onnxruntime_test_distributed.py](
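
The three algorithm steps in the commit message can be illustrated with a minimal, pure-Python sketch. It is not the ONNX Runtime kernel: `distributed_expand`, `broadcast_shape`, `expand`, and `shape_of` are hypothetical helpers, tensors are plain nested lists, the device mesh is assumed to be 1-D, the sharding spec is simplified to one `"S"` (sharded) or `"R"` (replicated) entry per axis, and the input and target shape are assumed to have equal rank.

```python
from typing import Any, List, Sequence


def shape_of(x: Any) -> List[int]:
    """Shape of a rectangular nested list; a scalar has shape []."""
    shape: List[int] = []
    while isinstance(x, list):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return shape


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> List[int]:
    """Numpy-style multidirectional broadcast of two shapes (right-aligned)."""
    rank = max(len(a), len(b))
    a = [1] * (rank - len(a)) + list(a)
    b = [1] * (rank - len(b)) + list(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return out


def expand(x: Any, target: Sequence[int]) -> Any:
    """Repeat size-1 axes of nested list x until it has shape `target` (same rank)."""
    if not target:
        return x
    if len(x) == target[0]:
        items = x
    elif len(x) == 1:
        items = x * target[0]
    else:
        raise ValueError(f"cannot expand an axis of size {len(x)} to {target[0]}")
    return [expand(item, target[1:]) for item in items]


def distributed_expand(local_input: Any, spec: Sequence[str], mesh_size: int,
                       target_shape: Sequence[int]) -> Any:
    """Hypothetical DistributedExpand on a 1-D device mesh.

    spec[i] is "S" if axis i is split evenly across the mesh, "R" if replicated.
    The output uses the same spec as the input.
    """
    local_shape = shape_of(local_input)
    assert len(local_shape) == len(spec) == len(target_shape), "equal ranks assumed"
    # Step 1: logical (unsharded) input and output shapes.
    logical_in = [d * mesh_size if s == "S" else d for d, s in zip(local_shape, spec)]
    logical_out = broadcast_shape(logical_in, target_shape)
    # Step 2: sharded output shape. For mesh_size > 1 a sharded axis is never
    # size 1 logically, so broadcasting cannot expand it and the division is exact.
    local_out = [d // mesh_size if s == "S" else d for d, s in zip(logical_out, spec)]
    # Step 3: a purely local Expand; no communication between devices.
    return expand(local_input, local_out)


# First example from the PR description: shape [8, 1], spec S[0]R, device_mesh=[0,1].
full = [[i] for i in range(8)]
shards = [full[:4], full[4:]]  # device 0 holds rows 0-3, device 1 holds rows 4-7
outs = [distributed_expand(s, ["S", "R"], 2, [8, 2]) for s in shards]
print(shape_of(outs[0]))                          # [4, 2]
print(outs[0] + outs[1] == expand(full, [8, 2]))  # True
```

Concatenating the two local [4, 2] results along axis 0 reproduces the unsharded [8, 2] Expand, which is why the output keeps the input's `S[0]R` spec and no communication is required.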

| Name |
|---|
| external |
| patches |
| tensorboard |
| adjust_global_compile_flags.cmake |
| CMakeLists.txt |
| CMakeSettings.json |
| codeconv.runsettings |
| deps.txt |
| deps_update_and_upload.py |
| EnableVisualStudioCodeAnalysis.props |
| gdk_toolchain.cmake |
| Info.plist.in |
| libonnxruntime.pc.cmake.in |
| nuget_helpers.cmake |
| onnxruntime.cmake |
| onnxruntime_codegen_tvm.cmake |
| onnxruntime_common.cmake |
| onnxruntime_compile_triton_kernel.cmake |
| onnxruntime_config.h.in |
| onnxruntime_csharp.cmake |
| onnxruntime_flatbuffers.cmake |
| onnxruntime_framework.cmake |
| onnxruntime_framework.natvis |
| onnxruntime_fuzz_test.cmake |
| onnxruntime_graph.cmake |
| onnxruntime_ios.toolchain.cmake |
| onnxruntime_java.cmake |
| onnxruntime_java_unittests.cmake |
| onnxruntime_kernel_explorer.cmake |
| onnxruntime_language_interop_ops.cmake |
| onnxruntime_mlas.cmake |
| onnxruntime_nodejs.cmake |
| onnxruntime_objectivec.cmake |
| onnxruntime_opschema_lib.cmake |
| onnxruntime_optimizer.cmake |
| onnxruntime_providers.cmake |
| onnxruntime_providers_acl.cmake |
| onnxruntime_providers_armnn.cmake |
| onnxruntime_providers_azure.cmake |
| onnxruntime_providers_cann.cmake |
| onnxruntime_providers_coreml.cmake |
| onnxruntime_providers_cpu.cmake |
| onnxruntime_providers_cuda.cmake |
| onnxruntime_providers_dml.cmake |
| onnxruntime_providers_dnnl.cmake |
| onnxruntime_providers_js.cmake |
| onnxruntime_providers_migraphx.cmake |
| onnxruntime_providers_nnapi.cmake |
| onnxruntime_providers_openvino.cmake |
| onnxruntime_providers_qnn.cmake |
| onnxruntime_providers_rknpu.cmake |
| onnxruntime_providers_rocm.cmake |
| onnxruntime_providers_tensorrt.cmake |
| onnxruntime_providers_tvm.cmake |
| onnxruntime_providers_vitisai.cmake |
| onnxruntime_providers_webnn.cmake |
| onnxruntime_providers_xnnpack.cmake |
| onnxruntime_pyop.cmake |
| onnxruntime_python.cmake |
| onnxruntime_rocm_hipify.cmake |
| onnxruntime_session.cmake |
| onnxruntime_snpe_provider.cmake |
| onnxruntime_training.cmake |
| onnxruntime_unittests.cmake |
| onnxruntime_util.cmake |
| onnxruntime_webassembly.cmake |
| precompiled_header.cmake |
| Sdl.ruleset |
| set_winapi_family_desktop.h |
| target_delayload.cmake |
| uwp_stubs.h |
| wcos_rules_override.cmake |
| winml.cmake |
| winml_cppwinrt.cmake |
| winml_sdk_helpers.cmake |
| winml_unittests.cmake |