199 lines
6.5 KiB
HLSL
199 lines
6.5 KiB
HLSL
|
#ifndef THREADING
|
||
|
#define THREADING
|
||
|
|
||
|
///
|
||
|
/// Compute Shader Threading Utilities
|
||
|
///
|
||
|
/// This file is intended to provide a portable implementation of the wave-level operations in DirectX Shader Model 6.0.
|
||
|
///
|
||
|
/// The functions in this file will automatically resolve to native intrinsics when possible.
|
||
|
/// A fallback groupshared memory implementation is used when native support is not available.
|
||
|
///
|
||
|
/// Usage:
|
||
|
///
|
||
|
/// To use this file, define all required preprocessor symbols and then include this file in your compute shader.
|
||
|
///
|
||
|
/// Required Preprocessor Symbols:
|
||
|
///
|
||
|
/// THREADING_BLOCK_SIZE
|
||
|
/// - The compute shader's flattened thread group size
|
||
|
///
|
||
|
/// Optional Preprocessor Symbols:
|
||
|
///
|
||
|
/// THREADING_WAVE_SIZE
|
||
|
/// - The size of a wave within the compute shader
|
||
|
/// - This symbol MUST be defined when authoring shader code that requires a specific wave size for correctness!
|
||
|
///
|
||
|
/// THREADING_FORCE_WAVE_EMULATION
|
||
|
/// - If defined, forces usage of the fallback groupshared memory implementation
|
||
|
///
|
||
|
|
||
|
// THREADING_BLOCK_SIZE is mandatory: stop compilation right away if the including shader did not define it.
#if !defined(THREADING_BLOCK_SIZE)
#error THREADING_BLOCK_SIZE must be defined as the flattened thread group size.
#endif
|
||
|
|
||
|
// Selection between the native path and the emulated (fallback) path.
//
// The native path is only considered supported when BOTH of the following hold:
//   - UNITY_HW_SUPPORTS_WAVE is defined, and
//   - THREADING_WAVE_SIZE is either not defined, or UNITY_HW_WAVE_SIZE is defined and equal to THREADING_WAVE_SIZE.
//
// The emulation path is enabled whenever the native path is not supported. It is also enabled,
// regardless of hardware support, when THREADING_FORCE_WAVE_EMULATION is defined (debug/testing aid).
#define _THREADING_IS_HW_SUPPORTED \
    (defined(UNITY_HW_SUPPORTS_WAVE) && \
        (!defined(THREADING_WAVE_SIZE) || \
            (defined(UNITY_HW_WAVE_SIZE) && (UNITY_HW_WAVE_SIZE == THREADING_WAVE_SIZE))))

#define _THREADING_ENABLE_WAVE_EMULATION \
    (!_THREADING_IS_HW_SUPPORTED || defined(THREADING_FORCE_WAVE_EMULATION))

// THREADING_BLOCK_SIZE divided by 32, rounded up. Used as the length of the GroupBallot::dwords array.
#define _THREADING_GROUP_BALLOT_DWORDS \
    ((THREADING_BLOCK_SIZE + 31u) / 32u)
|
||
|
|
||
|
namespace Threading
|
||
|
{
|
||
|
// Wave-level operations.
// Only declarations live in this struct; the definitions are expected to come from the
// implementation file included at the bottom of this header.
struct Wave
{
    // Internal state. These fields are meant to be private, but 'private' is a reserved
    // keyword in HLSL, so they cannot be marked as such.
    uint indexG;
    uint indexW;
#if _THREADING_ENABLE_WAVE_EMULATION
    // State that only exists when the emulation path is enabled.
    uint indexL;
    uint offset; // Per-wave offset into LDS scratch space.
#endif

    void Init(uint groupIndex);

    uint GetIndex();

    // Operations that are not overloaded per value type, so they are declared exactly once.
    uint GetLaneCount();
    uint GetLaneIndex();
    bool IsFirstLane();
    bool AllTrue(bool v);
    bool AnyTrue(bool v);
    uint4 Ballot(bool v);
    uint CountBits(bool v);
    uint PrefixCountBits(bool v);
    uint And(uint v);
    uint Or(uint v);
    uint Xor(uint v);

    // Operations that are overloaded per value type: one full set of declarations per type T.
    #define DECLARE_API_FOR_TYPE(T)     \
        bool AllEqual(T v);             \
        T Product(T v);                 \
        T Sum(T v);                     \
        T Max(T v);                     \
        T Min(T v);                     \
        T InclusivePrefixSum(T v);      \
        T InclusivePrefixProduct(T v);  \
        T PrefixSum(T v);               \
        T PrefixProduct(T v);           \
        T ReadLaneAt(T v, uint i);      \
        T ReadLaneFirst(T v);

    // Currently just support scalars.
    DECLARE_API_FOR_TYPE(uint)
    DECLARE_API_FOR_TYPE(int)
    DECLARE_API_FOR_TYPE(float)
};
|
||
|
|
||
|
// Result type of Group::Ballot(): an array of _THREADING_GROUP_BALLOT_DWORDS 32-bit words.
struct GroupBallot
{
    uint dwords[_THREADING_GROUP_BALLOT_DWORDS];

    // Returns the total number of set bits across all words of the ballot.
    uint CountBits()
    {
        uint total = 0;

        [unroll]
        for (uint i = 0; i < _THREADING_GROUP_BALLOT_DWORDS; i++)
        {
            total += countbits(dwords[i]);
        }

        return total;
    }
};
|
||
|
|
||
|
// Thread-group-level operations.
// The data members are bound to compute shader system-value semantics. Apart from GetWave() and
// RemapLaneTo8x16(), only declarations live in this struct; the definitions are expected to come
// from the implementation file included at the bottom of this header.
struct Group
{
    uint groupIndex : SV_GroupIndex;
    uint3 groupID : SV_GroupID;
    uint3 dispatchID : SV_DispatchThreadID;

    // Returns a Wave that starts out all-zero and is then initialized from this thread's groupIndex.
    Wave GetWave()
    {
        Wave result = (Wave)0;
        result.Init(groupIndex);
        return result;
    }

    // Lane remap which is safe for both portability (different wave sizes up to 128) and for 2D wave reductions.
    //
    // Which bit of the lane index feeds which output coordinate:
    //     6543210
    //     =======
    //     ..xx..x
    //     yy..yy.
    //
    // Details,
    //     LANE TO 8x16 MAPPING
    //     ====================
    //     00 01 08 09 10 11 18 19
    //     02 03 0a 0b 12 13 1a 1b
    //     04 05 0c 0d 14 15 1c 1d
    //     06 07 0e 0f 16 17 1e 1f
    //     20 21 28 29 30 31 38 39
    //     22 23 2a 2b 32 33 3a 3b
    //     24 25 2c 2d 34 35 3c 3d
    //     26 27 2e 2f 36 37 3e 3f
    //     .......................
    //     ... repeat the 8x8 ....
    //     .... pattern, but .....
    //     .... for 40 to 7f .....
    //     .......................
    //
    // NOTE: This function is only intended to be used with one dimensional thread groups
    uint2 RemapLaneTo8x16()
    {
        // Note the BFIs used for MSBs have "strange offsets" due to leaving space for the LSB bits replaced in the BFI.
        const uint xUpper = BitFieldExtract(groupIndex, 2u, 3u);
        const uint x = BitFieldInsert(1u, groupIndex, xUpper);

        const uint yLower = BitFieldExtract(groupIndex, 1u, 2u);
        const uint yUpper = BitFieldExtract(groupIndex, 3u, 4u);
        const uint y = BitFieldInsert(3u, yLower, yUpper);

        return uint2(x, y);
    }

    uint GetWaveCount();

    // Operations that are not overloaded per value type, so they are declared exactly once.
    uint GetThreadCount();
    uint GetThreadIndex();
    bool IsFirstThread();
    bool AllTrue(bool v);
    bool AnyTrue(bool v);
    GroupBallot Ballot(bool v);
    uint CountBits(bool v);
    uint PrefixCountBits(bool v);
    uint And(uint v);
    uint Or(uint v);
    uint Xor(uint v);

    // Operations that are overloaded per value type: one full set of declarations per type T.
    #define DECLARE_API_FOR_TYPE_GROUP(T)   \
        bool AllEqual(T v);                 \
        T Product(T v);                     \
        T Sum(T v);                         \
        T Max(T v);                         \
        T Min(T v);                         \
        T InclusivePrefixSum(T v);          \
        T InclusivePrefixProduct(T v);      \
        T PrefixSum(T v);                   \
        T PrefixProduct(T v);               \
        T ReadThreadAt(T v, uint i);        \
        T ReadThreadFirst(T v);             \
        T ReadThreadShuffle(T v, uint i);

    // Currently just support scalars.
    DECLARE_API_FOR_TYPE_GROUP(uint)
    DECLARE_API_FOR_TYPE_GROUP(int)
    DECLARE_API_FOR_TYPE_GROUP(float)
};
|
||
|
}
|
||
|
|
||
|
// Include the implementation file that matches the path selected by _THREADING_ENABLE_WAVE_EMULATION.
#if !_THREADING_ENABLE_WAVE_EMULATION
#include "ThreadingSM6Impl.hlsl"
#else
#include "ThreadingEmuImpl.hlsl"
#endif
|
||
|
|
||
|
#endif
|