Skip to content

Commit 980347e

Browse files
ENhance NormalizedFloatToByteSaturate
1 parent 2d979f9 commit 980347e

14 files changed

Lines changed: 443 additions & 229 deletions

File tree

src/ImageSharp/Common/Helpers/Numerics.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,26 @@ public static nuint Vector256Count<TVector>(this ReadOnlySpan<byte> span)
10101010
where TVector : struct
10111011
=> (uint)span.Length / (uint)Vector256<TVector>.Count;
10121012

1013+
/// <summary>
1014+
/// Gets the count of vectors that safely fit into the given span.
1015+
/// </summary>
1016+
/// <typeparam name="TVector">The type of the vector.</typeparam>
1017+
/// <param name="span">The given span.</param>
1018+
/// <returns>Count of vectors that safely fit into the span.</returns>
1019+
public static nuint Vector512Count<TVector>(this Span<byte> span)
1020+
where TVector : struct
1021+
=> (uint)span.Length / (uint)Vector512<TVector>.Count;
1022+
1023+
/// <summary>
1024+
/// Gets the count of vectors that safely fit into the given span.
1025+
/// </summary>
1026+
/// <typeparam name="TVector">The type of the vector.</typeparam>
1027+
/// <param name="span">The given span.</param>
1028+
/// <returns>Count of vectors that safely fit into the span.</returns>
1029+
public static nuint Vector512Count<TVector>(this ReadOnlySpan<byte> span)
1030+
where TVector : struct
1031+
=> (uint)span.Length / (uint)Vector512<TVector>.Count;
1032+
10131033
/// <summary>
10141034
/// Gets the count of vectors that safely fit into the given span.
10151035
/// </summary>

src/ImageSharp/Common/Helpers/SimdUtils.ExtendedIntrinsics.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ internal static void NormalizedFloatToByteSaturateReduce(
9595
/// </summary>
9696
internal static void ByteToNormalizedFloat(ReadOnlySpan<byte> source, Span<float> dest)
9797
{
98-
VerifySpanInput(source, dest, Vector<byte>.Count);
98+
DebugVerifySpanInput(source, dest, Vector<byte>.Count);
9999

100100
nuint n = dest.VectorCount<byte>();
101101

@@ -130,7 +130,7 @@ internal static void NormalizedFloatToByteSaturate(
130130
ReadOnlySpan<float> source,
131131
Span<byte> dest)
132132
{
133-
VerifySpanInput(source, dest, Vector<byte>.Count);
133+
DebugVerifySpanInput(source, dest, Vector<byte>.Count);
134134

135135
nuint n = dest.VectorCount<byte>();
136136

src/ImageSharp/Common/Helpers/SimdUtils.FallbackIntrinsics128.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ internal static void NormalizedFloatToByteSaturateReduce(
6969
[MethodImpl(InliningOptions.ColdPath)]
7070
internal static void ByteToNormalizedFloat(ReadOnlySpan<byte> source, Span<float> dest)
7171
{
72-
VerifySpanInput(source, dest, 4);
72+
DebugVerifySpanInput(source, dest, 4);
7373

7474
uint count = (uint)dest.Length / 4;
7575
if (count == 0)
@@ -103,7 +103,7 @@ internal static void NormalizedFloatToByteSaturate(
103103
ReadOnlySpan<float> source,
104104
Span<byte> dest)
105105
{
106-
VerifySpanInput(source, dest, 4);
106+
DebugVerifySpanInput(source, dest, 4);
107107

108108
uint count = (uint)source.Length / 4;
109109
if (count == 0)

src/ImageSharp/Common/Helpers/SimdUtils.HwIntrinsics.cs

Lines changed: 97 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,13 @@ internal static partial class SimdUtils
1717
{
1818
public static class HwIntrinsics
1919
{
20+
#pragma warning disable SA1117 // Parameters should be on same line or separate lines
21+
#pragma warning disable SA1137 // Elements should have the same indentation
2022
[MethodImpl(MethodImplOptions.AggressiveInlining)] // too much IL for JIT to inline, so give a hint
21-
public static Vector256<int> PermuteMaskDeinterleave8x32() => Vector256.Create(0, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 5, 0, 0, 0, 2, 0, 0, 0, 6, 0, 0, 0, 3, 0, 0, 0, 7, 0, 0, 0).AsInt32();
23+
public static Vector256<int> PermuteMaskDeinterleave8x32() => Vector256.Create(0, 4, 1, 5, 2, 6, 3, 7);
24+
25+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
26+
public static Vector512<int> PermuteMaskDeinterleave16x32() => Vector512.Create(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15);
2227

2328
[MethodImpl(MethodImplOptions.AggressiveInlining)]
2429
public static Vector256<uint> PermuteMaskEvenOdd8x32() => Vector256.Create(0, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0, 5, 0, 0, 0, 7, 0, 0, 0).AsUInt32();
@@ -38,17 +43,18 @@ public static class HwIntrinsics
3843
[MethodImpl(MethodImplOptions.AggressiveInlining)]
3944
private static Vector128<byte> ShuffleMaskSlice4Nx16() => Vector128.Create(0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 0x80, 0x80, 0x80, 0x80);
4045

41-
#pragma warning disable SA1003, SA1116, SA1117 // Parameters should be on same line or separate lines
4246
[MethodImpl(MethodImplOptions.AggressiveInlining)]
43-
private static Vector256<byte> ShuffleMaskShiftAlpha() => Vector256.Create((byte)
44-
0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 3, 7, 11, 15,
45-
0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 3, 7, 11, 15);
47+
private static Vector256<byte> ShuffleMaskShiftAlpha() => Vector256.Create(
48+
(byte)0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 3, 7, 11, 15,
49+
0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 3, 7, 11, 15);
4650

4751
[MethodImpl(MethodImplOptions.AggressiveInlining)]
48-
public static Vector256<uint> PermuteMaskShiftAlpha8x32() => Vector256.Create(
49-
0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0,
50-
5, 0, 0, 0, 6, 0, 0, 0, 3, 0, 0, 0, 7, 0, 0, 0).AsUInt32();
51-
#pragma warning restore SA1003, SA1116, SA1117 // Parameters should be on same line or separate lines
52+
public static Vector256<uint> PermuteMaskShiftAlpha8x32()
53+
=> Vector256.Create(
54+
0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0,
55+
5, 0, 0, 0, 6, 0, 0, 0, 3, 0, 0, 0, 7, 0, 0, 0).AsUInt32();
56+
#pragma warning restore SA1137 // Elements should have the same indentation
57+
#pragma warning restore SA1117 // Parameters should be on same line or separate lines
5258

5359
/// <summary>
5460
/// Shuffle single-precision (32-bit) floating-point elements in <paramref name="source"/>
@@ -795,7 +801,7 @@ internal static unsafe void ByteToNormalizedFloat(
795801
{
796802
if (Avx2.IsSupported)
797803
{
798-
VerifySpanInput(source, dest, Vector256<byte>.Count);
804+
DebugVerifySpanInput(source, dest, Vector256<byte>.Count);
799805

800806
nuint n = dest.Vector256Count<byte>();
801807

@@ -828,7 +834,7 @@ internal static unsafe void ByteToNormalizedFloat(
828834
else
829835
{
830836
// Sse
831-
VerifySpanInput(source, dest, Vector128<byte>.Count);
837+
DebugVerifySpanInput(source, dest, Vector128<byte>.Count);
832838

833839
nuint n = dest.Vector128Count<byte>();
834840

@@ -881,17 +887,24 @@ internal static unsafe void ByteToNormalizedFloat(
881887
/// <summary>
882888
/// <see cref="NormalizedFloatToByteSaturate"/> as many elements as possible, slicing them down (keeping the remainder).
883889
/// </summary>
890+
/// <param name="source">The source buffer.</param>
891+
/// <param name="destination">The destination buffer.</param>
884892
[MethodImpl(InliningOptions.ShortMethod)]
885893
internal static void NormalizedFloatToByteSaturateReduce(
886894
ref ReadOnlySpan<float> source,
887-
ref Span<byte> dest)
895+
ref Span<byte> destination)
888896
{
889-
DebugGuard.IsTrue(source.Length == dest.Length, nameof(source), "Input spans must be of same length!");
897+
DebugGuard.IsTrue(source.Length == destination.Length, nameof(source), "Input spans must be of same length!");
890898

891-
if (Avx2.IsSupported || Sse2.IsSupported)
899+
if (Avx512BW.IsSupported || Avx2.IsSupported || Sse2.IsSupported || AdvSimd.IsSupported)
892900
{
893901
int remainder;
894-
if (Avx2.IsSupported)
902+
903+
if (Avx512BW.IsSupported)
904+
{
905+
remainder = Numerics.ModuloP2(source.Length, Vector512<byte>.Count);
906+
}
907+
else if (Avx2.IsSupported)
895908
{
896909
remainder = Numerics.ModuloP2(source.Length, Vector256<byte>.Count);
897910
}
@@ -906,36 +919,70 @@ internal static void NormalizedFloatToByteSaturateReduce(
906919
{
907920
NormalizedFloatToByteSaturate(
908921
source[..adjustedCount],
909-
dest[..adjustedCount]);
922+
destination[..adjustedCount]);
910923

911924
source = source[adjustedCount..];
912-
dest = dest[adjustedCount..];
925+
destination = destination[adjustedCount..];
913926
}
914927
}
915928
}
916929

917930
/// <summary>
918931
/// Implementation of <see cref="SimdUtils.NormalizedFloatToByteSaturate"/>, which is faster on new .NET runtime.
919932
/// </summary>
933+
/// <param name="source">The source buffer.</param>
934+
/// <param name="destination">The destination buffer.</param>
920935
/// <remarks>
921936
/// Implementation is based on MagicScaler code:
922937
/// https://github.com/saucecontrol/PhotoSauce/blob/b5811908041200488aa18fdfd17df5fc457415dc/src/MagicScaler/Magic/Processors/ConvertersFloat.cs#L541-L622
923938
/// </remarks>
924939
internal static void NormalizedFloatToByteSaturate(
925940
ReadOnlySpan<float> source,
926-
Span<byte> dest)
941+
Span<byte> destination)
927942
{
928-
if (Avx2.IsSupported)
943+
if (Avx512BW.IsSupported)
929944
{
930-
VerifySpanInput(source, dest, Vector256<byte>.Count);
945+
DebugVerifySpanInput(source, destination, Vector512<byte>.Count);
946+
947+
nuint n = destination.Vector512Count<byte>();
931948

932-
nuint n = dest.Vector256Count<byte>();
949+
ref Vector512<float> sourceBase = ref Unsafe.As<float, Vector512<float>>(ref MemoryMarshal.GetReference(source));
950+
ref Vector512<byte> destinationBase = ref Unsafe.As<byte, Vector512<byte>>(ref MemoryMarshal.GetReference(destination));
933951

934-
ref Vector256<float> sourceBase =
935-
ref Unsafe.As<float, Vector256<float>>(ref MemoryMarshal.GetReference(source));
952+
Vector512<float> scale = Vector512.Create((float)byte.MaxValue);
953+
Vector512<int> mask = PermuteMaskDeinterleave16x32();
936954

937-
ref Vector256<byte> destBase =
938-
ref Unsafe.As<byte, Vector256<byte>>(ref MemoryMarshal.GetReference(dest));
955+
for (nuint i = 0; i < n; i++)
956+
{
957+
ref Vector512<float> s = ref Unsafe.Add(ref sourceBase, i * 4);
958+
959+
Vector512<float> f0 = scale * s;
960+
Vector512<float> f1 = scale * Unsafe.Add(ref s, 1);
961+
Vector512<float> f2 = scale * Unsafe.Add(ref s, 2);
962+
Vector512<float> f3 = scale * Unsafe.Add(ref s, 3);
963+
964+
Vector512<int> w0 = Vector512Utilities.ConvertToInt32RoundToEven(f0);
965+
Vector512<int> w1 = Vector512Utilities.ConvertToInt32RoundToEven(f1);
966+
Vector512<int> w2 = Vector512Utilities.ConvertToInt32RoundToEven(f2);
967+
Vector512<int> w3 = Vector512Utilities.ConvertToInt32RoundToEven(f3);
968+
969+
Vector512<short> u0 = Avx512BW.PackSignedSaturate(w0, w1);
970+
Vector512<short> u1 = Avx512BW.PackSignedSaturate(w2, w3);
971+
Vector512<byte> b = Avx512BW.PackUnsignedSaturate(u0, u1);
972+
b = Avx512F.PermuteVar16x32(b.AsInt32(), mask).AsByte();
973+
974+
Unsafe.Add(ref destinationBase, i) = b;
975+
}
976+
}
977+
else
978+
if (Avx2.IsSupported)
979+
{
980+
DebugVerifySpanInput(source, destination, Vector256<byte>.Count);
981+
982+
nuint n = destination.Vector256Count<byte>();
983+
984+
ref Vector256<float> sourceBase = ref Unsafe.As<float, Vector256<float>>(ref MemoryMarshal.GetReference(source));
985+
ref Vector256<byte> destinationBase = ref Unsafe.As<byte, Vector256<byte>>(ref MemoryMarshal.GetReference(destination));
939986

940987
Vector256<float> scale = Vector256.Create((float)byte.MaxValue);
941988
Vector256<int> mask = PermuteMaskDeinterleave8x32();
@@ -944,57 +991,54 @@ internal static void NormalizedFloatToByteSaturate(
944991
{
945992
ref Vector256<float> s = ref Unsafe.Add(ref sourceBase, i * 4);
946993

947-
Vector256<float> f0 = Avx.Multiply(scale, s);
948-
Vector256<float> f1 = Avx.Multiply(scale, Unsafe.Add(ref s, 1));
949-
Vector256<float> f2 = Avx.Multiply(scale, Unsafe.Add(ref s, 2));
950-
Vector256<float> f3 = Avx.Multiply(scale, Unsafe.Add(ref s, 3));
994+
Vector256<float> f0 = scale * s;
995+
Vector256<float> f1 = scale * Unsafe.Add(ref s, 1);
996+
Vector256<float> f2 = scale * Unsafe.Add(ref s, 2);
997+
Vector256<float> f3 = scale * Unsafe.Add(ref s, 3);
951998

952-
Vector256<int> w0 = Avx.ConvertToVector256Int32(f0);
953-
Vector256<int> w1 = Avx.ConvertToVector256Int32(f1);
954-
Vector256<int> w2 = Avx.ConvertToVector256Int32(f2);
955-
Vector256<int> w3 = Avx.ConvertToVector256Int32(f3);
999+
Vector256<int> w0 = Vector256Utilities.ConvertToInt32RoundToEven(f0);
1000+
Vector256<int> w1 = Vector256Utilities.ConvertToInt32RoundToEven(f1);
1001+
Vector256<int> w2 = Vector256Utilities.ConvertToInt32RoundToEven(f2);
1002+
Vector256<int> w3 = Vector256Utilities.ConvertToInt32RoundToEven(f3);
9561003

9571004
Vector256<short> u0 = Avx2.PackSignedSaturate(w0, w1);
9581005
Vector256<short> u1 = Avx2.PackSignedSaturate(w2, w3);
9591006
Vector256<byte> b = Avx2.PackUnsignedSaturate(u0, u1);
9601007
b = Avx2.PermuteVar8x32(b.AsInt32(), mask).AsByte();
9611008

962-
Unsafe.Add(ref destBase, i) = b;
1009+
Unsafe.Add(ref destinationBase, i) = b;
9631010
}
9641011
}
9651012
else
9661013
{
967-
// Sse
968-
VerifySpanInput(source, dest, Vector128<byte>.Count);
969-
970-
nuint n = dest.Vector128Count<byte>();
1014+
// Sse, AdvSimd
1015+
DebugVerifySpanInput(source, destination, Vector128<byte>.Count);
9711016

972-
ref Vector128<float> sourceBase =
973-
ref Unsafe.As<float, Vector128<float>>(ref MemoryMarshal.GetReference(source));
1017+
nuint n = destination.Vector128Count<byte>();
9741018

975-
ref Vector128<byte> destBase =
976-
ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(dest));
1019+
ref Vector128<float> sourceBase = ref Unsafe.As<float, Vector128<float>>(ref MemoryMarshal.GetReference(source));
1020+
ref Vector128<byte> destinationBase = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(destination));
9771021

9781022
Vector128<float> scale = Vector128.Create((float)byte.MaxValue);
9791023

9801024
for (nuint i = 0; i < n; i++)
9811025
{
9821026
ref Vector128<float> s = ref Unsafe.Add(ref sourceBase, i * 4);
9831027

984-
Vector128<float> f0 = Sse.Multiply(scale, s);
985-
Vector128<float> f1 = Sse.Multiply(scale, Unsafe.Add(ref s, 1));
986-
Vector128<float> f2 = Sse.Multiply(scale, Unsafe.Add(ref s, 2));
987-
Vector128<float> f3 = Sse.Multiply(scale, Unsafe.Add(ref s, 3));
1028+
Vector128<float> f0 = scale * s;
1029+
Vector128<float> f1 = scale * Unsafe.Add(ref s, 1);
1030+
Vector128<float> f2 = scale * Unsafe.Add(ref s, 2);
1031+
Vector128<float> f3 = scale * Unsafe.Add(ref s, 3);
9881032

989-
Vector128<int> w0 = Sse2.ConvertToVector128Int32(f0);
990-
Vector128<int> w1 = Sse2.ConvertToVector128Int32(f1);
991-
Vector128<int> w2 = Sse2.ConvertToVector128Int32(f2);
992-
Vector128<int> w3 = Sse2.ConvertToVector128Int32(f3);
1033+
Vector128<int> w0 = Vector128Utilities.ConvertToInt32RoundToEven(f0);
1034+
Vector128<int> w1 = Vector128Utilities.ConvertToInt32RoundToEven(f1);
1035+
Vector128<int> w2 = Vector128Utilities.ConvertToInt32RoundToEven(f2);
1036+
Vector128<int> w3 = Vector128Utilities.ConvertToInt32RoundToEven(f3);
9931037

994-
Vector128<short> u0 = Sse2.PackSignedSaturate(w0, w1);
995-
Vector128<short> u1 = Sse2.PackSignedSaturate(w2, w3);
1038+
Vector128<short> u0 = Vector128Utilities.PackSignedSaturate(w0, w1);
1039+
Vector128<short> u1 = Vector128Utilities.PackSignedSaturate(w2, w3);
9961040

997-
Unsafe.Add(ref destBase, i) = Sse2.PackUnsignedSaturate(u0, u1);
1041+
Unsafe.Add(ref destinationBase, i) = Vector128Utilities.PackUnsignedSaturate(u0, u1);
9981042
}
9991043
}
10001044
}

0 commit comments

Comments
 (0)