#pragma kernel BitonicSort128 BITONIC_SORT=BitonicSort128 ELEMENTS_COUNT=128 ITERATIONS_COUNT=1 FINAL_PASS=1
#pragma kernel BitonicSort1024 BITONIC_SORT=BitonicSort1024 ELEMENTS_COUNT=1024 ITERATIONS_COUNT=2 FINAL_PASS=1
#pragma kernel BitonicSort1024_128 BITONIC_SORT=BitonicSort1024_128 ELEMENTS_COUNT=1024 ITERATIONS_COUNT=4 FINAL_PASS=1
#pragma kernel BitonicSort2048 BITONIC_SORT=BitonicSort2048 ELEMENTS_COUNT=2048 ITERATIONS_COUNT=2 FINAL_PASS=1
#pragma kernel BitonicSort2048_128 BITONIC_SORT=BitonicSort2048_128 ELEMENTS_COUNT=2048 ITERATIONS_COUNT=8 FINAL_PASS=1
#pragma kernel BitonicSort4096 BITONIC_SORT=BitonicSort4096 ELEMENTS_COUNT=4096 ITERATIONS_COUNT=2 FINAL_PASS=1
#pragma kernel BitonicSort4096_128 BITONIC_SORT=BitonicSort4096_128 ELEMENTS_COUNT=4096 ITERATIONS_COUNT=16 FINAL_PASS=1
#pragma kernel BitonicPrePass BITONIC_SORT=BitonicPrePass ELEMENTS_COUNT=4096 ITERATIONS_COUNT=2 FINAL_PASS=0
#pragma kernel BitonicPrePass4096_128 BITONIC_SORT=BitonicPrePass4096_128 ELEMENTS_COUNT=4096 ITERATIONS_COUNT=16 FINAL_PASS=0
#pragma kernel BitonicPrePass2048_128 BITONIC_SORT=BitonicPrePass2048_128 ELEMENTS_COUNT=2048 ITERATIONS_COUNT=8 FINAL_PASS=0
#pragma kernel MergePass MERGE_PASS=MergePass FINAL_PASS=0
#pragma kernel MergeFinalPass MERGE_PASS=MergeFinalPass FINAL_PASS=1

#pragma only_renderers d3d11 playstation xboxone xboxseries vulkan metal switch glcore gles3 webgpu

#include "HLSLSupport.cginc"

// 1: keys are sorted in decreasing order, 0: increasing order.
#define DECREASING_ORDER 1
// 1: bound the merge-pass binary search to a fixed number of iterations (debug aid).
#define DEBUG_NO_INFINITE_LOOP 0

// Fallback values, used when a kernel pragma does not define them.
#ifndef ELEMENTS_COUNT
#define ELEMENTS_COUNT 1024
#endif

#ifndef ITERATIONS_COUNT
#define ITERATIONS_COUNT 1
#endif

// 1 to use the alternative representation of the bitonic network:
// No comparison flipping but dedicated first sub pass.
// see https://en.wikipedia.org/wiki/Bitonic_sorter
#if defined(SHADER_API_METAL)
// Workaround: BitonicSort128 doesn't behave as expected, actual cause isn't identified yet.
#define USE_ALTERNATE_BITONIC_NETWORK 1
#else
#define USE_ALTERNATE_BITONIC_NETWORK 0
#endif

// Each thread performs ITERATIONS_COUNT compare/swap operations per sub step,
// and each compare/swap touches two elements.
#define ELEMENTS_PER_THREAD (2 * ITERATIONS_COUNT)
// Parenthesized so the division cannot re-associate with neighbouring operators at the use site.
#define BITONIC_THREADS_COUNT (ELEMENTS_COUNT / ELEMENTS_PER_THREAD)
#define MERGE_THREADS_COUNT ELEMENTS_COUNT

#define HALF_SCRATCH_SIZE ELEMENTS_COUNT
#define LDS_VALUES_OFFSET 1
#define SCRATCH_SIZE (HALF_SCRATCH_SIZE << 1)

// Index of the compare/swap pair handled by thread `id` during its iteration `it`.
#define ITERATION_INDEX(id, it) ((it) * BITONIC_THREADS_COUNT + (id))
// Index of an element relative to the start of its instance.
#define INSTANCE_DST_INDEX(groupId, threadId) ((groupId).x * BITONIC_THREADS_COUNT * ELEMENTS_PER_THREAD + (threadId))
// Same index, shifted to the instance selected by groupId.z.
#define DST_INDEX(groupId, threadId) (INSTANCE_DST_INDEX(groupId, threadId) + (groupId).z * instanceOffset)

// Key given to padding elements so that they end up after every real element.
#if DECREASING_ORDER
#define REJECTED_VALUE asfloat(0xff7fffff) // -MAX_FLOAT
#else
#define REJECTED_VALUE asfloat(0x7f7fffff) // MAX_FLOAT
#endif

#pragma warning(disable : 3557) // disable warning for auto unrolling of single iteration loop

// Key/value pair: elements are ordered by `key`, ties are broken by `value`.
struct KVP
{
    float key;
    uint value;
};

// Entries [0, totalInstanceCount) carry the element count of each instance in `.value`;
// the elements themselves are read starting at index totalInstanceCount.
StructuredBuffer<KVP> inputSequence;

// The final pass only outputs the values, intermediate passes output complete pairs.
#if FINAL_PASS
RWStructuredBuffer<uint> sortedSequence;
#else
RWStructuredBuffer<KVP> sortedSequence;
#endif

// Layout of the scratch memory is as follows:
// {Key,Value},{Key,Value}, ...
// Values are stored as the bit pattern of the uint (asfloat / asuint round trip).
groupshared float scratch[SCRATCH_SIZE];

CBUFFER_START(SortConstants)
    uint instanceOffset;
    uint totalInstanceCount;
CBUFFER_END

// Returns the number of elements to sort for the given instance.
uint GetElementCount(uint instanceIndex)
{
    return inputSequence[instanceIndex].value;
}

// Returns true when kvp0 must be placed before kvp1.
// Keys follow DECREASING_ORDER; equal keys are ordered by ascending value.
bool CompareKVP(KVP kvp0, KVP kvp1)
{
    if (kvp0.key == kvp1.key)
        return kvp0.value < kvp1.value; // value always in ascending order
    else
#if DECREASING_ORDER
        return kvp0.key > kvp1.key;
#else
        return kvp0.key < kvp1.key;
#endif
}

// Converts an element index into the scratch index of its key
// (the value is stored LDS_VALUES_OFFSET floats further).
uint GetLDSIndex(uint index)
{
    return 2 * index;
}

// Loads one element into scratch memory. Slots located past the element count
// of the instance are filled with a rejected key and the value 0xFFFFFFFF.
void LoadFromMemory(uint3 groupId, uint ldsIndex)
{
    uint instanceDstIndex = INSTANCE_DST_INDEX(groupId, ldsIndex);
    uint memIndex = DST_INDEX(groupId, ldsIndex);

    KVP kvp = { REJECTED_VALUE, 0xFFFFFFFF };
    if (instanceDstIndex < GetElementCount(groupId.z))
        kvp = inputSequence[totalInstanceCount + memIndex];

    uint paddedLdsIndex = GetLDSIndex(ldsIndex);
    scratch[paddedLdsIndex] = kvp.key;
    scratch[paddedLdsIndex + LDS_VALUES_OFFSET] = asfloat(kvp.value);
}

// Writes one element from scratch memory to the output buffer.
// Padding slots (index >= element count of the instance) are not written.
void StoreToMemory(uint3 groupId, uint ldsIndex)
{
    uint instanceDstIndex = INSTANCE_DST_INDEX(groupId, ldsIndex);
    uint memIndex = DST_INDEX(groupId, ldsIndex);

    if (instanceDstIndex < GetElementCount(groupId.z))
    {
        uint paddedLdsIndex = GetLDSIndex(ldsIndex);
        uint value = asuint(scratch[paddedLdsIndex + LDS_VALUES_OFFSET]);
#if FINAL_PASS
        sortedSequence[totalInstanceCount + memIndex] = value;
#else
        float key = scratch[paddedLdsIndex];
        KVP kvp = { key, value };
        sortedSequence[totalInstanceCount + memIndex] = kvp;
#endif
    }
}

// Bitonic sort on small chunks of kvp of size ELEMENTS_COUNT - execute in O(log^2(ELEMENTS_COUNT))
// groupId.x selects the chunk inside the instance, groupId.z selects the instance.
[numthreads(BITONIC_THREADS_COUNT, 1, 1)]
void BITONIC_SORT(uint id : SV_GroupIndex, uint3 groupId : SV_GroupID)
{
    // Skip useless groups
    // (the condition only depends on groupId, so a whole group leaves together
    // and the barriers below are never reached by a partial group)
    if (groupId.x > GetElementCount(groupId.z) / ELEMENTS_COUNT)
        return;

    // A single thread per instance forwards the element count entry to the output buffer.
    if (groupId.x == 0 && id == 0)
    {
#if FINAL_PASS
        sortedSequence[groupId.z] = inputSequence[groupId.z].value;
#else
        sortedSequence[groupId.z] = inputSequence[groupId.z];
#endif
    }

    // Load data from memory to LDS
    //[unroll]
    for (uint i = 0; i < ELEMENTS_PER_THREAD; ++i)
    {
        uint index = BITONIC_THREADS_COUNT * i + id;
        LoadFromMemory(groupId, index);
    }

    GroupMemoryBarrierWithGroupSync(); // LDS Writes visible

    for (uint step = 1; step < ELEMENTS_COUNT; step <<= 1) // O(log(ELEMENTS_COUNT))
        for (uint subStep = step; subStep != 0; subStep >>= 1) // O(log(step))
        {
            [unroll]
            for (uint i = 0; i < ITERATIONS_COUNT; ++i)
            {
                // Pair of elements compared by this thread for this iteration.
                uint index = ITERATION_INDEX(id, i);
                uint lsb = index & (subStep - 1);
                uint index0 = (2 * index) - lsb;
                uint index1 = index0 + subStep;
#if USE_ALTERNATE_BITONIC_NETWORK
                // Dedicated first sub pass: the partner index is mirrored inside the block.
                if (subStep == step)
                    index1 += step - (2 * lsb) - 1;
#endif

                uint ldsIndex0 = GetLDSIndex(index0);
                uint ldsIndex1 = GetLDSIndex(index1);

                float key0 = scratch[ldsIndex0];
                float value0 = scratch[ldsIndex0 + LDS_VALUES_OFFSET];
                float key1 = scratch[ldsIndex1];
                float value1 = scratch[ldsIndex1 + LDS_VALUES_OFFSET];

                KVP kvp0 = { key0, asuint(value0) };
                KVP kvp1 = { key1, asuint(value1) };

#if USE_ALTERNATE_BITONIC_NETWORK
                bool reverse = false;
#else
                // Comparison direction flips according to the `step` bit of the pair index.
                bool reverse = (index & step) != 0;
#endif

                if (CompareKVP(kvp1, kvp0) != reverse)
                {
                    // swap first
                    scratch[ldsIndex0] = key1;
                    scratch[ldsIndex0 + LDS_VALUES_OFFSET] = value1;

                    // swap second
                    scratch[ldsIndex1] = key0;
                    scratch[ldsIndex1 + LDS_VALUES_OFFSET] = value0;
                }
            }

            GroupMemoryBarrierWithGroupSync(); // LDS Writes visible
        }

    // Store sorted data from LDS to memory
    //[unroll]
    for (uint j = 0; j < ELEMENTS_PER_THREAD; ++j)
    {
        uint index = BITONIC_THREADS_COUNT * j + id;
        StoreToMemory(groupId, index);
    }
}

CBUFFER_START(MergePassConstants)
    uint subArraySize;
    uint dispatchWidth;
CBUFFER_END

// Merge pass: take N sorted sub arrays as input and output N/2 sorted arrays twice bigger - execute in O(log(subArraySize))
// Each thread computes one output element by binary searching for its source
// in the two sub arrays being merged.
#define NB_THREADS_PER_GROUP_MERGE_PASS 64
[numthreads(NB_THREADS_PER_GROUP_MERGE_PASS, 1, 1)]
void MERGE_PASS(uint3 groupId : SV_GroupID, uint3 groupThreadId : SV_GroupThreadID)
{
    // Flattened element index inside the instance (the dispatch spreads over x and y).
    uint id = groupThreadId.x
        + groupId.x * NB_THREADS_PER_GROUP_MERGE_PASS
        + groupId.y * dispatchWidth * NB_THREADS_PER_GROUP_MERGE_PASS;

#if FINAL_PASS
    // A single thread per instance forwards the element count to the output buffer.
    if (groupId.x == 0 && id == 0)
        sortedSequence[groupId.z] = inputSequence[groupId.z].value;
#endif

    if (id >= GetElementCount(groupId.z))
        return;

    // Start of the elements of this instance inside the buffers.
    const uint offset = totalInstanceCount + groupId.z * instanceOffset;

    const int arraySize = subArraySize << 1;
    const int arrayStart = arraySize * (id / arraySize);

    // If the current array considered is less than one half filled (due to element count), we can copy it directly as it is already sorted
    if (GetElementCount(groupId.z) - (uint)arrayStart <= subArraySize)
    {
#if FINAL_PASS
        sortedSequence[id + offset] = inputSequence[id + offset].value;
#else
        sortedSequence[id + offset] = inputSequence[id + offset];
#endif
        return;
    }

    const int arrayIndex = id - arrayStart;

    // Last valid local index in the first and second sub arrays
    // (the second one can be partially filled).
    const int lastIndex0 = subArraySize - 1;
    const int lastIndex1 = min(subArraySize - 1, GetElementCount(groupId.z) - (arrayStart + subArraySize + 1));

    // determine initial frame of the window
    int2 window = uint2(max(0, arrayIndex - lastIndex1 - 1), min(arrayIndex, lastIndex0));

    int index0, index1;
    KVP kvp0, kvp1;
    bool reverse = false;
    bool done = false;

#if DEBUG_NO_INFINITE_LOOP
    uint nbIter = 0;
#endif

    // Binary search - O(log(subArraySize))
    do
    {
        int windowIndex = (window.x + window.y + 1) >> 1;

        int i0 = min(lastIndex0, windowIndex);
        int i1 = min(lastIndex1, arrayIndex - windowIndex);

        index0 = i0 + arrayStart;
        index1 = i1 + arrayStart + subArraySize;

        kvp0 = inputSequence[index0 + offset];
        kvp1 = inputSequence[index1 + offset];

        if (i0 + i1 == arrayIndex)
        {
            if (i0 > 0 && CompareKVP(kvp1, inputSequence[max(0, index0 - 1) + offset]))
                window.y = windowIndex - 1; // move window left
            else if (i1 > 0 && CompareKVP(kvp0, inputSequence[max(0, index1 - 1) + offset]))
                window.x = windowIndex + (windowIndex == lastIndex0 ? 2 : 1); // move window right (Special handling at the right bound so that i1 can go down to 0)
            else
                done = true;
        }
        else // special case handling
        {
            // One of the indices was clamped: invert the final selection.
            reverse = true;
            done = true;
        }

#if DEBUG_NO_INFINITE_LOOP
        if (++nbIter > 2048)
        {
            kvp0.key = kvp1.key = 0.5f;
            break;
        }
#endif
    } while (!done);

    // Select left or right: Must handle special case for equality
    bool select0 = CompareKVP(kvp0, kvp1) != reverse;

#if FINAL_PASS
    sortedSequence[id + offset] = select0 ? kvp0.value : kvp1.value;
#else
    if (select0)
        kvp1 = kvp0;
    sortedSequence[id + offset] = kvp1;
#endif
}