ComputeShaderでSpringBone

環境

Unity2022.2.2f1

概要

UnityちゃんのSpringBoneの動きをComputeShaderで実装してみました。 (SDUnityちゃんに付属しているもの)

画像の白いものが本家のSpringBone、Gizmoで表示されているものがComputeShaderで実装したものです。（もう一つは、テスト用にCPU側でMatrix4x4を使って実装したものです）

同じ動きになりました。

コード

// Unity shader pragma: compile this kernel with D3D11 debug symbols for GPU shader debugging.
// NOTE(review): consider removing this for release builds.
#pragma enable_d3d11_debug_symbols

// Quaternions are stored as float4 (x, y, z, w) with w as the scalar part (see AngleAxis).
#define Quaternion float4
// Near-zero threshold used by Vec3Angle (length product) and FromToRotation (axis squared length).
#define Epsilon (1e-15)

// Square root that never propagates NaN/Inf: if sqrt(v) is not finite
// (for example v < 0, or v == +Inf) the result is 0 instead.
float Sqrt(float v)
{
    float root = sqrt(v);
    if (isfinite(root))
        return root;
    return 0.0;
}

// Reciprocal square root that never propagates NaN/Inf: rsqrt(0) (= +Inf) and
// rsqrt of a negative value (= NaN) both yield 0, so Normalize(0) stays 0.
float Rsqrt(float v)
{
    float inverseRoot = rsqrt(v);
    if (isfinite(inverseRoot))
        return inverseRoot;
    return 0.0;
}

// Unit-length copy of v; a zero (or non-finite-length) vector yields the zero vector
// because Rsqrt returns 0 in that case.
float3 Normalize(float3 v)
{
    float invLength = Rsqrt(dot(v, v));
    return v * invLength;
}

// Quaternion for a rotation of aAngle RADIANS about aAxis (the axis is normalized here).
// Note: unlike Unity's Quaternion.AngleAxis, the angle is not in degrees.
Quaternion AngleAxis(float aAngle, float3 aAxis)
{
    float halfAngle = aAngle * 0.5;
    float3 vectorPart = Normalize(aAxis) * sin(halfAngle);
    return Quaternion(vectorPart, cos(halfAngle));
}

// Squared length of v (x^2 + y^2 + z^2), summed left to right.
float SqrMagnitude(float3 v)
{
    float3 squared = v * v;
    return squared.x + squared.y + squared.z;
}

// Unsigned angle between two vectors, in RADIANS (acos result, not converted to degrees).
// Returns 0 when either vector is (nearly) zero length, i.e. when the product of
// their lengths is below Epsilon.
float Vec3Angle(float3 from, float3 to)
{
    float lengthProduct = Sqrt(SqrMagnitude(from) * SqrMagnitude(to));
    if (lengthProduct < Epsilon)
        return 0.0;

    // Clamp guards acos against rounding pushing the cosine slightly outside [-1, 1].
    float cosine = clamp(dot(from, to) / lengthProduct, -1.0, 1.0);
    return acos(cosine);
}

// Shortest-arc rotation that turns direction aFrom onto direction aTo.
// Port of the reference Quaternion.FromToRotation implementation cited at the end of this
// article, with one difference: angles here are in RADIANS (Vec3Angle returns acos(...) and
// AngleAxis expects radians), so the near-180-degree threshold must be in radians as well.
Quaternion FromToRotation(float3 aFrom, float3 aTo)
{
    float3 axis = cross(aFrom, aTo);
    float angle = Vec3Angle(aFrom, aTo);
    // When the vectors are (nearly) opposite, cross() collapses toward zero, so choose any
    // axis perpendicular to aFrom instead. The threshold was written as 179.9196 (degrees)
    // and compared against a radian angle, so this branch could never be taken.
    if (angle >= radians(179.9196))
    {
        // (aFrom x X) x aFrom is the part of +X perpendicular to aFrom.
        float3 r = cross(aFrom, float3(1.0, 0.0, 0.0));
        axis = cross(r, aFrom);
        // aFrom is (nearly) parallel to the X axis, so +Y is perpendicular to it.
        if (SqrMagnitude(axis) < Epsilon)
            axis = float3(0.0, 1.0, 0.0);
    }
    return AngleAxis(angle, Normalize(axis));
}

// Converts the rotation matrix m into a quaternion (x, y, z, w) with w >= 0.
// m uses the same convention as ToMatrix (m._m21 - m._m12 == 4*x*w, etc.).
//
// Uses the largest of |w|, |x|, |y|, |z| as the pivot (Shepperd's method), so the divisor
// s is always >= 2 for a proper rotation matrix. The previous per-component sqrt + sign()
// version zeroed a component whenever its sign term was exactly 0 (sign(0) == 0). For
// example, a 180-degree turn about X became (0,0,0,0). It also could not recover the
// relative signs of x/y/z when w == 0.
// NOTE(review): assumes m is orthonormal (no scale/shear).
Quaternion ToRotation(float3x3 m)
{
    Quaternion q;
    float trace = m._m00 + m._m11 + m._m22;
    if (trace > 0.0)
    {
        float s = sqrt(1.0 + trace) * 2.0;                               // s = 4w
        q.w = 0.25 * s;
        q.x = (m._m21 - m._m12) / s;
        q.y = (m._m02 - m._m20) / s;
        q.z = (m._m10 - m._m01) / s;
    }
    else if (m._m00 > m._m11 && m._m00 > m._m22)
    {
        float s = sqrt(max(0.0, 1.0 + m._m00 - m._m11 - m._m22)) * 2.0; // s = 4x
        q.w = (m._m21 - m._m12) / s;
        q.x = 0.25 * s;
        q.y = (m._m01 + m._m10) / s;
        q.z = (m._m02 + m._m20) / s;
    }
    else if (m._m11 > m._m22)
    {
        float s = sqrt(max(0.0, 1.0 - m._m00 + m._m11 - m._m22)) * 2.0; // s = 4y
        q.w = (m._m02 - m._m20) / s;
        q.x = (m._m01 + m._m10) / s;
        q.y = 0.25 * s;
        q.z = (m._m12 + m._m21) / s;
    }
    else
    {
        float s = sqrt(max(0.0, 1.0 - m._m00 - m._m11 + m._m22)) * 2.0; // s = 4z
        q.w = (m._m10 - m._m01) / s;
        q.x = (m._m02 + m._m20) / s;
        q.y = (m._m12 + m._m21) / s;
        q.z = 0.25 * s;
    }
    // q and -q describe the same rotation; keep w non-negative as before.
    return (q.w < 0.0) ? -q : q;
}

// Builds the 3x3 rotation matrix of q, for use as mul(m, v).
// q is normalized first. A zero quaternion produces the identity, because Rsqrt(0) returns 0.
float3x3 ToMatrix(Quaternion q)
{
    float invNorm = Rsqrt(q.x * q.x + q.y * q.y + q.z * q.z + q.w * q.w);
    Quaternion u = q * invNorm;

    // Doubled components; each product below equals the original 2.0 * a * b.
    float x2 = 2.0 * u.x;
    float y2 = 2.0 * u.y;
    float z2 = 2.0 * u.z;

    float xx2 = x2 * u.x;
    float yy2 = y2 * u.y;
    float zz2 = z2 * u.z;
    float xy2 = x2 * u.y;
    float xz2 = x2 * u.z;
    float yz2 = y2 * u.z;
    float xw2 = x2 * u.w;
    float yw2 = y2 * u.w;
    float zw2 = z2 * u.w;

    return float3x3(
        1.0 - yy2 - zz2, xy2 - zw2,       xz2 + yw2,
        xy2 + zw2,       1.0 - xx2 - zz2, yz2 - xw2,
        xz2 - yw2,       yz2 + xw2,       1.0 - xx2 - yy2
    );
}

// Rotates vector p by quaternion `rotation` via the expanded rotation-matrix form.
// rotation is NOT normalized here; callers are expected to pass a unit quaternion.
float3 MulQuatVec3(Quaternion rotation, float3 p)
{
    float3 doubled = rotation.xyz * 2.0;
    float3 squares = rotation.xyz * doubled;   // (xx, yy, zz)
    float xy = rotation.x * doubled.y;
    float xz = rotation.x * doubled.z;
    float yz = rotation.y * doubled.z;
    float3 wTimes = rotation.w * doubled;      // (wx, wy, wz)

    float3 row0 = float3(1.0 - (squares.y + squares.z), xy - wTimes.z, xz + wTimes.y);
    float3 row1 = float3(xy + wTimes.z, 1.0 - (squares.x + squares.z), yz - wTimes.x);
    float3 row2 = float3(xz - wTimes.y, yz + wTimes.x, 1.0 - (squares.x + squares.y));

    return float3(
        row0.x * p.x + row0.y * p.y + row0.z * p.z,
        row1.x * p.x + row1.y * p.y + row1.z * p.z,
        row2.x * p.x + row2.y * p.y + row2.z * p.z);
}

// Hamilton product lhs * rhs. The result applies rhs first, then lhs.
Quaternion MulQuat(Quaternion lhs, Quaternion rhs)
{
    float px = lhs.w * rhs.x + lhs.x * rhs.w + lhs.y * rhs.z - lhs.z * rhs.y;
    float py = lhs.w * rhs.y + lhs.y * rhs.w + lhs.z * rhs.x - lhs.x * rhs.z;
    float pz = lhs.w * rhs.z + lhs.z * rhs.w + lhs.x * rhs.y - lhs.y * rhs.x;
    float pw = lhs.w * rhs.w - lhs.x * rhs.x - lhs.y * rhs.y - lhs.z * rhs.z;
    return Quaternion(px, py, pz, pw);
}

// Packed array of spring nodes, SpringNodeStride bytes each. The layout is described in the
// comment just above SpringNodeStride. Read and written in place by UpdateSpring.
RWByteAddressBuffer SpringNodes;
// Number of nodes in SpringNodes; Springs processes indices 0 .. ChainNum-1 in order.
uint ChainNum;
// World matrix used as the parent of the first node in the chain.
float4x4 WorldMatrix;
// Per-frame time step; UpdateSpring squares it for the Verlet integration.
// NOTE(review): presumably set from Time.deltaTime on the C# side - confirm.
float DeltaTime;

/*
Per-node layout of SpringNodes (tightly packed, 156 bytes), mirroring this C# struct:
public Vector3 boneAxis;// = new Vector3 (-1.0f, 0.0f, 0.0f);
public float radius;// = 0.05f;
public float stiffnessForce;// = 0.01f;
public float dragForce;// = 0.4f;
public Vector3 springForce;// = new Vector3 (0.0f, -0.0001f, 0.0f);
public float springLength;
public Vector3 localPos;
public Quaternion localRotation;
//todo: only the fields below this line are written back, so ideally the struct would be split in two
public Matrix4x4 matrix;
public Vector3 currTipPos;
public Vector3 prevTipPos;
*/
#define SpringNodeStride 156

// Simulates one spring node for one frame and writes its new state back to SpringNodes.
//   chainIndex   : node index; its data starts at byte chainIndex * SpringNodeStride.
//   parentMatrix : in  = world matrix of the parent node (Springs starts with WorldMatrix);
//                  out = this node's new world matrix, used as the next node's parent.
// Reads every field except radius. Writes matrix, currTipPos and prevTipPos.
// NOTE(review): ToMatrix/ToRotation treat the upper 3x3 as a pure rotation, and the outgoing
//               parentMatrix is rebuilt from a quaternion, so any scale in WorldMatrix is not
//               carried down the chain - confirm that is intended.
void UpdateSpring(uint chainIndex, inout float4x4 parentMatrix)
{
    // ---- Load node state ----
    uint address = chainIndex * SpringNodeStride;
    float3 boneAxis = asfloat(SpringNodes.Load3(address));
    address += 3 * 4;
    // radius is part of the layout but unused by this kernel; skip it.
    address += 4;
    float stiffnessForce = asfloat(SpringNodes.Load(address));
    address += 4;
    float dragForce = asfloat(SpringNodes.Load(address));
    address += 4;
    float3 springForce = asfloat(SpringNodes.Load3(address));
    address += 3 * 4;
    float springLength = asfloat(SpringNodes.Load(address));
    address += 4;
    float3 localPos = asfloat(SpringNodes.Load3(address));
    address += 3 * 4;
    Quaternion localRotation = asfloat(SpringNodes.Load4(address));
    address += 4 * 4;

    // The four stored float4s are columns (Unity's Matrix4x4 is column-major); building a
    // float4x4 from them as rows and transposing yields the actual matrix.
    uint matrixAddress = address;
    float4 c0 = asfloat(SpringNodes.Load4(address + 0 * 4));
    float4 c1 = asfloat(SpringNodes.Load4(address + 4 * 4));
    float4 c2 = asfloat(SpringNodes.Load4(address + 8 * 4));
    float4 c3 = asfloat(SpringNodes.Load4(address + 12 * 4));
    float4x4 springMatrix = transpose(float4x4(c0, c1, c2, c3));
    address += 4 * 4 * 4;
    float3 currTipPos = asfloat(SpringNodes.Load3(address));
    address += 3 * 4;
    float3 prevTipPos = asfloat(SpringNodes.Load3(address));

    // ---- Reset rotation: rest pose = parent * local transform ----
    float3x3 rotMatrix = mul((float3x3)parentMatrix, ToMatrix(localRotation));
    springMatrix._m00_m10_m20 = rotMatrix._m00_m10_m20;
    springMatrix._m01_m11_m21 = rotMatrix._m01_m11_m21;
    springMatrix._m02_m12_m22 = rotMatrix._m02_m12_m22;
    float4 pos = mul(parentMatrix, float4(localPos, 1.0));
    springMatrix._m03_m13_m23 = pos.xyz;
    float3 bonePos = springMatrix._m03_m13_m23;
    // Rest-pose rotation, needed by both the stiffness force and the final aim; compute it once.
    Quaternion restRotation = ToRotation((float3x3)springMatrix);

    // ---- Verlet step ----
    float sqrDt = DeltaTime * DeltaTime;
    // Every force term is divided by sqrDt and the Verlet step multiplies by sqrDt again.
    // With DeltaTime == 0 that is inf * 0 = NaN, and the NaN tip would be stored and persist.
    // When no time has elapsed, skip the integration and keep the tip state unchanged.
    if (sqrDt > 0.0)
    {
        // stiffness: push the tip along the rest-pose bone axis
        float3 force = MulQuatVec3(restRotation, boneAxis * (stiffnessForce / sqrDt));
        // drag: opposes the tip's per-frame velocity (currTipPos - prevTipPos)
        force += (prevTipPos - currTipPos) * dragForce / sqrDt;
        // constant external force (e.g. gravity)
        force += springForce / sqrDt;

        // Remember the current tip; it becomes the previous tip for the next frame.
        float3 lastTipPos = currTipPos;
        currTipPos = (currTipPos - prevTipPos) + currTipPos + (force * sqrDt);
        // Restore the spring length: put the tip back at springLength from the bone position.
        currTipPos = (Normalize(currTipPos - bonePos) * springLength) + bonePos;
        prevTipPos = lastTipPos;
    }

    // ---- Apply rotation: turn the bone so boneAxis points at the tip ----
    float3 aimVector = mul((float3x3)springMatrix, boneAxis);
    Quaternion aimRotation = FromToRotation(aimVector, currTipPos - bonePos);
    // A Lerp(rest, aim * rest, mixWeight) variant (previously kept here as disabled C# code)
    // gives the same rotation at mixWeight 1.0.
    Quaternion rotation = MulQuat(aimRotation, restRotation);
    rotMatrix = ToMatrix(rotation);
    springMatrix._m00_m10_m20 = rotMatrix._m00_m10_m20;
    springMatrix._m01_m11_m21 = rotMatrix._m01_m11_m21;
    springMatrix._m02_m12_m22 = rotMatrix._m02_m12_m22;
    parentMatrix = springMatrix;

    // ---- Write back in the same column-major layout that was read ----
    springMatrix = transpose(springMatrix);
    SpringNodes.Store4(matrixAddress +  0 * 4, asuint(springMatrix[0]));
    SpringNodes.Store4(matrixAddress +  4 * 4, asuint(springMatrix[1]));
    SpringNodes.Store4(matrixAddress +  8 * 4, asuint(springMatrix[2]));
    SpringNodes.Store4(matrixAddress + 12 * 4, asuint(springMatrix[3]));
    uint tipAddress = matrixAddress + 4 * 4 * 4;
    SpringNodes.Store3(tipAddress, asuint(currTipPos));
    SpringNodes.Store3(tipAddress + 3 * 4, asuint(prevTipPos));
}

#pragma kernel Springs
// Walks the whole chain sequentially on one thread. Each node's world matrix is handed to
// the next node through UpdateSpring's inout parentMatrix, starting from WorldMatrix.
// NOTE(review): the thread id is not used, so every dispatched thread would run the full
//               chain; presumably this is dispatched as (1,1,1) - confirm on the C# side.
[numthreads(1,1,1)]
void Springs(uint3 id : SV_DispatchThreadID)
{
    float4x4 parentMatrix = WorldMatrix;
    uint node = 0;
    while (node < ChainNum)
    {
        UpdateSpring(node, parentMatrix);
        ++node;
    }
}

参考

Whats the source code of Quaternion.FromToRotation? - Unity Answers