#pragma kernel BitonicSort128 BITONIC_SORT=BitonicSort128 ELEMENTS_COUNT=128 ITERATIONS_COUNT=1 FINAL_PASS=1

#pragma kernel BitonicSort1024 BITONIC_SORT=BitonicSort1024 ELEMENTS_COUNT=1024 ITERATIONS_COUNT=2 FINAL_PASS=1
#pragma kernel BitonicSort1024_128 BITONIC_SORT=BitonicSort1024_128 ELEMENTS_COUNT=1024 ITERATIONS_COUNT=4 FINAL_PASS=1

#pragma kernel BitonicSort2048 BITONIC_SORT=BitonicSort2048 ELEMENTS_COUNT=2048 ITERATIONS_COUNT=2 FINAL_PASS=1
#pragma kernel BitonicSort2048_128 BITONIC_SORT=BitonicSort2048_128 ELEMENTS_COUNT=2048 ITERATIONS_COUNT=8 FINAL_PASS=1

#pragma kernel BitonicSort4096 BITONIC_SORT=BitonicSort4096 ELEMENTS_COUNT=4096 ITERATIONS_COUNT=2 FINAL_PASS=1
#pragma kernel BitonicSort4096_128 BITONIC_SORT=BitonicSort4096_128 ELEMENTS_COUNT=4096 ITERATIONS_COUNT=16 FINAL_PASS=1

#pragma kernel BitonicPrePass BITONIC_SORT=BitonicPrePass ELEMENTS_COUNT=4096 ITERATIONS_COUNT=2 FINAL_PASS=0
#pragma kernel BitonicPrePass4096_128 BITONIC_SORT=BitonicPrePass4096_128 ELEMENTS_COUNT=4096 ITERATIONS_COUNT=16 FINAL_PASS=0
#pragma kernel BitonicPrePass2048_128 BITONIC_SORT=BitonicPrePass2048_128 ELEMENTS_COUNT=2048 ITERATIONS_COUNT=8 FINAL_PASS=0

#pragma kernel MergePass MERGE_PASS=MergePass FINAL_PASS=0
#pragma kernel MergeFinalPass MERGE_PASS=MergeFinalPass FINAL_PASS=1
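
// Kernel variants (derived from the macro parameters above and the definitions below):
// - Each BITONIC_SORT variant sorts ELEMENTS_COUNT key/value pairs per thread group entirely in groupshared memory.
//   ITERATIONS_COUNT controls how many pairs each thread handles (ELEMENTS_PER_THREAD = 2 * ITERATIONS_COUNT), so the
//   *_128 variants run with 128 threads per group (e.g. 4096 elements / 32 elements per thread).
// - FINAL_PASS=1 kernels write only the sorted values; FINAL_PASS=0 kernels (BitonicPrePass*) keep {key,value} pairs
//   so that MergePass/MergeFinalPass can keep merging sorted chunks into larger ones.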

#pragma only_renderers d3d11 playstation xboxone xboxseries vulkan metal switch glcore gles3 webgpu

#include "HLSLSupport.cginc"

#define DECREASING_ORDER 1
#define DEBUG_NO_INFINITE_LOOP 0

#ifndef ELEMENTS_COUNT
#define ELEMENTS_COUNT 1024
#endif

#ifndef ITERATIONS_COUNT
#define ITERATIONS_COUNT 1
#endif

// Set to 1 to use the alternative representation of the bitonic network:
// no comparison flipping, but a dedicated first sub-pass instead.
// See https://en.wikipedia.org/wiki/Bitonic_sorter
#if defined(SHADER_API_METAL)
// Workaround: BitonicSort128 doesn't behave as expected on Metal; the actual cause hasn't been identified yet.
#define USE_ALTERNATE_BITONIC_NETWORK 1
#else
#define USE_ALTERNATE_BITONIC_NETWORK 0
#endif

#define ELEMENTS_PER_THREAD (2 * ITERATIONS_COUNT)
#define BITONIC_THREADS_COUNT ELEMENTS_COUNT / ELEMENTS_PER_THREAD
#define MERGE_THREADS_COUNT ELEMENTS_COUNT

#define HALF_SCRATCH_SIZE ELEMENTS_COUNT

#define LDS_VALUES_OFFSET 1
#define SCRATCH_SIZE (HALF_SCRATCH_SIZE << 1)

#define ITERATION_INDEX(id,it) ((it) * BITONIC_THREADS_COUNT + (id))
#define INSTANCE_DST_INDEX(groupId,threadId) (groupId.x * BITONIC_THREADS_COUNT * ELEMENTS_PER_THREAD + threadId)
#define DST_INDEX(groupId,threadId) (INSTANCE_DST_INDEX(groupId,threadId) + groupId.z * instanceOffset)
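
// Indexing sketch (derived from the macros above):
//   INSTANCE_DST_INDEX = element index within the current sort instance; each group covers
//                        BITONIC_THREADS_COUNT * ELEMENTS_PER_THREAD = ELEMENTS_COUNT elements.
//   DST_INDEX          = the same index shifted by groupId.z * instanceOffset to address that instance's slice of the buffer.
// Example for BitonicSort1024 (ELEMENTS_COUNT=1024, ITERATIONS_COUNT=2): ELEMENTS_PER_THREAD = 4, BITONIC_THREADS_COUNT = 256,
// so group x covers elements [1024*x .. 1024*x + 1023] of its instance's slice.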

#if DECREASING_ORDER
#define REJECTED_VALUE asfloat(0xff7fffff) // -MAX_FLOAT
#else
#define REJECTED_VALUE asfloat(0x7f7fffff) // MAX_FLOAT
#endif

#pragma warning(disable : 3557) // disable warning for auto unrolling of single iteration loop

struct KVP
{
    float key;
    uint value;
};

StructuredBuffer<KVP> inputSequence;
#if FINAL_PASS
RWStructuredBuffer<uint> sortedSequence;
#else
RWStructuredBuffer<KVP> sortedSequence;
#endif

// Layout of the scratch memory is as follows:
// {Key,Value},{Key,Value}, ...
groupshared float scratch[SCRATCH_SIZE];

CBUFFER_START(SortConstants)
    uint instanceOffset;
    uint totalInstanceCount;
CBUFFER_END
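
// Buffer layout (as used by GetElementCount and the totalInstanceCount offset in the kernels below):
//   inputSequence[0 .. totalInstanceCount-1] : per-instance header; .value holds that instance's element count.
//   inputSequence[totalInstanceCount .. ]    : the key/value payload; instance slices start instanceOffset elements apart,
//                                              with groupId.z selecting the instance.
// sortedSequence uses the same offsets, and the kernels copy the current instance's header entry (index groupId.z) through.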

uint GetElementCount(uint instanceIndex)
{
    return inputSequence[instanceIndex].value;
}

bool CompareKVP(KVP kvp0, KVP kvp1)
{
    if (kvp0.key == kvp1.key)
        return kvp0.value < kvp1.value; // value always in ascending order
    else
#if DECREASING_ORDER
        return kvp0.key > kvp1.key;
#else
        return kvp0.key < kvp1.key;
#endif
}

uint GetLDSIndex(uint index)
{
    return 2 * index;
}
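
// With the interleaved scratch layout, element i occupies scratch[GetLDSIndex(i)] = scratch[2*i] for the key and
// scratch[2*i + LDS_VALUES_OFFSET] for the value (stored as float bits via asfloat/asuint), which is why
// SCRATCH_SIZE is twice HALF_SCRATCH_SIZE.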

void LoadFromMemory(uint3 groupId, uint ldsIndex)
{
    uint instanceDstIndex = INSTANCE_DST_INDEX(groupId, ldsIndex);
    uint memIndex = DST_INDEX(groupId, ldsIndex);

    KVP kvp = { REJECTED_VALUE, 0xFFFFFFFF };
    if (instanceDstIndex < GetElementCount(groupId.z))
        kvp = inputSequence[totalInstanceCount + memIndex];

    uint paddedLdsIndex = GetLDSIndex(ldsIndex);
    scratch[paddedLdsIndex] = kvp.key;
    scratch[paddedLdsIndex + LDS_VALUES_OFFSET] = asfloat(kvp.value);
}

void StoreToMemory(uint3 groupId, uint ldsIndex)
{
    uint instanceDstIndex = INSTANCE_DST_INDEX(groupId, ldsIndex);
    uint memIndex = DST_INDEX(groupId, ldsIndex);

    if (instanceDstIndex < GetElementCount(groupId.z))
    {
        uint paddedLdsIndex = GetLDSIndex(ldsIndex);
        uint value = asuint(scratch[paddedLdsIndex + LDS_VALUES_OFFSET]);
#if FINAL_PASS
        sortedSequence[totalInstanceCount + memIndex] = value;
#else
        float key = scratch[paddedLdsIndex];
        KVP kvp = { key, value };
        sortedSequence[totalInstanceCount + memIndex] = kvp;
#endif
    }
}
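
// When instanceDstIndex is past the instance's element count, LoadFromMemory fills the slot with
// { REJECTED_VALUE, 0xFFFFFFFF } so padded entries sort to the end (REJECTED_VALUE is -MAX_FLOAT when DECREASING_ORDER),
// and StoreToMemory never writes those slots back; this lets a group sort a partially filled chunk.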

// Bitonic sort on small chunks of KVPs of size ELEMENTS_COUNT - executes in O(log²(ELEMENTS_COUNT))
[numthreads(BITONIC_THREADS_COUNT, 1, 1)]
void BITONIC_SORT(uint id : SV_GroupIndex, uint3 groupId : SV_GroupID)
{
    // Skip useless groups
    if (groupId.x > GetElementCount(groupId.z) / ELEMENTS_COUNT)
        return;

    if (groupId.x == 0 && id == 0)
#if FINAL_PASS
        sortedSequence[groupId.z] = inputSequence[groupId.z].value;
#else
        sortedSequence[groupId.z] = inputSequence[groupId.z];
#endif

    // Load data from memory to LDS
    //[unroll]
    for (uint i = 0; i < ELEMENTS_PER_THREAD; ++i)
    {
        uint index = BITONIC_THREADS_COUNT * i + id;
        LoadFromMemory(groupId, index);
    }

    GroupMemoryBarrierWithGroupSync(); // LDS writes visible

    for (uint step = 1; step < ELEMENTS_COUNT; step <<= 1) // O(log(ELEMENTS_COUNT))
        for (uint subStep = step; subStep != 0; subStep >>= 1) // O(log(step))
        {
            [unroll]
            for (uint i = 0; i < ITERATIONS_COUNT; ++i)
            {
                uint index = ITERATION_INDEX(id, i);
                uint lsb = index & (subStep - 1);
                uint index0 = (2 * index) - lsb;
                uint index1 = index0 + subStep;
#if USE_ALTERNATE_BITONIC_NETWORK
                if (subStep == step)
                    index1 += step - (2 * lsb) - 1;
#endif
                uint ldsIndex0 = GetLDSIndex(index0);
                uint ldsIndex1 = GetLDSIndex(index1);

                float key0 = scratch[ldsIndex0];
                float value0 = scratch[ldsIndex0 + LDS_VALUES_OFFSET];

                float key1 = scratch[ldsIndex1];
                float value1 = scratch[ldsIndex1 + LDS_VALUES_OFFSET];

                KVP kvp0 = { key0, asuint(value0) };
                KVP kvp1 = { key1, asuint(value1) };

#if USE_ALTERNATE_BITONIC_NETWORK
                bool reverse = false;
#else
                bool reverse = index & step;
#endif
                if (CompareKVP(kvp1, kvp0) != reverse)
                {
                    // swap first
                    scratch[ldsIndex0] = key1;
                    scratch[ldsIndex0 + LDS_VALUES_OFFSET] = value1;

                    // swap second
                    scratch[ldsIndex1] = key0;
                    scratch[ldsIndex1 + LDS_VALUES_OFFSET] = value0;
                }
            }

            GroupMemoryBarrierWithGroupSync(); // LDS writes visible
        }

    // Store sorted data from LDS to memory
    //[unroll]
    for (uint j = 0; j < ELEMENTS_PER_THREAD; ++j)
    {
        uint index = BITONIC_THREADS_COUNT * j + id; //(id / 32) * 64 * ITERATIONS_COUNT + (id & 31) + j * 32;//id * ITERATIONS_COUNT * 2 + j;
        StoreToMemory(groupId, index);
    }
}
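
// Note on the network above: for each (step, subStep) sub-pass every thread compare-exchanges ITERATIONS_COUNT pairs.
// index0 = 2*index - (index & (subStep-1)) and index1 = index0 + subStep select the two partners, and 'reverse' flips the
// comparison for alternating blocks of 2*step elements so that neighbouring sorted runs form bitonic sequences for the
// next step. With USE_ALTERNATE_BITONIC_NETWORK, the first sub-pass of each step instead pairs mirrored indices within
// the block and never flips the comparison (the alternative representation described at the Wikipedia link above).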

CBUFFER_START(MergePassConstants)
    uint subArraySize;
    uint dispatchWidth;
CBUFFER_END

// Merge pass: takes N sorted sub-arrays as input and outputs N/2 sorted arrays twice as big - executes in O(log(subArraySize))
#define NB_THREADS_PER_GROUP_MERGE_PASS 64
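
// Each thread produces one output element. For output position arrayIndex inside a merged array of 2*subArraySize elements,
// the loop below binary-searches for a split (i0, i1) with i0 + i1 == arrayIndex such that all elements before index0 in the
// first half and before index1 in the second half sort ahead of both candidates; whichever of inputSequence[index0] and
// inputSequence[index1] comes first in the sort order is then written to this position ('reverse' covers the clamped edge cases).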
[numthreads(NB_THREADS_PER_GROUP_MERGE_PASS, 1, 1)]
void MERGE_PASS(uint3 groupId : SV_GroupID,
                uint3 groupThreadId : SV_GroupThreadID)
{
    uint id = groupThreadId.x + groupId.x * NB_THREADS_PER_GROUP_MERGE_PASS + groupId.y * dispatchWidth * NB_THREADS_PER_GROUP_MERGE_PASS;

#if FINAL_PASS
    if (groupId.x == 0 && id == 0)
        sortedSequence[groupId.z] = inputSequence[groupId.z].value;
#endif

    if (id >= GetElementCount(groupId.z))
        return;

    const uint offset = totalInstanceCount + groupId.z * instanceOffset;

    const int arraySize = subArraySize << 1;
    const int arrayStart = arraySize * (id / arraySize);

    // If the current array is less than half filled (because of the element count), we can copy it directly as it is already sorted
    if (GetElementCount(groupId.z) - (uint)arrayStart <= subArraySize)
    {
#if FINAL_PASS
        sortedSequence[id + offset] = inputSequence[id + offset].value;
#else
        sortedSequence[id + offset] = inputSequence[id + offset];
#endif
        return;
    }

    const int arrayIndex = id - arrayStart;
    const int lastIndex0 = subArraySize - 1;
    const int lastIndex1 = min(subArraySize - 1, GetElementCount(groupId.z) - (arrayStart + subArraySize + 1));

    // determine the initial frame of the search window
    int2 window = int2(max(0, arrayIndex - lastIndex1 - 1), min(arrayIndex, lastIndex0));

    int index0, index1;
    KVP kvp0, kvp1;

    bool reverse = false;
    bool done = false;
#if DEBUG_NO_INFINITE_LOOP
    uint nbIter = 0;
#endif

    // Binary search - O(log(subArraySize))
    do
    {
        int windowIndex = (window.x + window.y + 1) >> 1;
        int i0 = min(lastIndex0, windowIndex);
        int i1 = min(lastIndex1, arrayIndex - windowIndex);

        index0 = i0 + arrayStart;
        index1 = i1 + arrayStart + subArraySize;

        kvp0 = inputSequence[index0 + offset];
        kvp1 = inputSequence[index1 + offset];

        if (i0 + i1 == arrayIndex)
        {
            if (i0 > 0 && CompareKVP(kvp1, inputSequence[max(0, index0 - 1) + offset]))
                window.y = windowIndex - 1; // move window left
            else if (i1 > 0 && CompareKVP(kvp0, inputSequence[max(0, index1 - 1) + offset]))
                window.x = windowIndex + (windowIndex == lastIndex0 ? 2 : 1); // move window right (special handling at the right bound so that i1 can go down to 0)
            else
                done = true;
        }
        else // special case handling
        {
            reverse = true;
            done = true;
        }

#if DEBUG_NO_INFINITE_LOOP
        if (++nbIter > 2048)
        {
            kvp0.key = kvp1.key = 0.5f;
            break;
        }
#endif

    } while (!done);

    // Select left or right: must handle the special case of equality
    bool select0 = CompareKVP(kvp0, kvp1) != reverse;

#if FINAL_PASS
    sortedSequence[id + offset] = select0 ? kvp0.value : kvp1.value;
#else
    if (select0)
        kvp1 = kvp0;
    sortedSequence[id + offset] = kvp1;
#endif
}
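
// Expected dispatch sequence (host-side assumption based on the kernels above, not defined in this file):
//   - For at most 4096 elements per instance, a single BitonicSort* dispatch (FINAL_PASS=1) sorts the sequence directly.
//   - For longer sequences, a BitonicPrePass* dispatch sorts chunks of ELEMENTS_COUNT elements, then MergePass is dispatched
//     repeatedly with subArraySize doubling from ELEMENTS_COUNT, and MergeFinalPass runs last so only the values are written out.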