/*
 * Decompiled with CFR 0.152.
 */
package net.countered.counteredsaccuratehitboxes.mixin.server;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import net.countered.counteredsaccuratehitboxes.util.HitboxAttachment;
import net.countered.counteredsaccuratehitboxes.util.Triangle;
import net.minecraft.class_1297;
import net.minecraft.class_1675;
import net.minecraft.class_1937;
import net.minecraft.class_238;
import net.minecraft.class_2394;
import net.minecraft.class_2398;
import net.minecraft.class_243;
import net.minecraft.class_3966;
import org.joml.Vector3f;
import org.joml.Vector3fc;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;

/**
 * Mixin into {@code class_1675} (presumably yarn's {@code ProjectileUtil} — confirm against the
 * mappings in use) that replaces its {@code raycast} method.
 *
 * <p>Entities carrying attached custom hitboxes ({@code HitboxAttachment.HITBOXES}) are
 * hit-tested against those hitboxes instead of their bounding box. Entities without attached
 * hitboxes keep a bounding-box test: their box, grown by {@code method_5871()}, is ray-cast
 * between the ray's start and end points.
 */
@Mixin(value = {class_1675.class})
public abstract class ModProjectileUtilMixin {

    /**
     * Replaces the target method at HEAD and always cancels it with this method's own result.
     *
     * <p>{@code maxDistance} seeds the best-hit distance and is compared against
     * {@code min.method_1025(...)} results. In yarn mappings that is the squared distance, so
     * {@code maxDistance} is presumably squared as well — confirm.
     *
     * @param entity      the entity casting the ray; passed to the world query and used for the
     *                    shared-root-vehicle check
     * @param min         ray start point
     * @param max         ray end point
     * @param box         region queried for candidate entities
     * @param predicate   filter the world query applies to candidates
     * @param maxDistance initial distance bound for accepted hits
     * @param cir         receives the closest hit, or {@code null} when nothing was hit
     */
    @Inject(method = {"raycast"}, at = {@At(value = "HEAD")}, cancellable = true)
    private static void injectCustomHitboxRaycast(class_1297 entity, class_243 min, class_243 max, class_238 box, Predicate<class_1297> predicate, double maxDistance, CallbackInfoReturnable<class_3966> cir) {
        class_1937 world = entity.method_37908();
        class_1297 shooterRoot = entity.method_5668();
        class_243 rayDir = max.method_1020(min).method_1029();
        double rayLength = min.method_1022(max);
        double closestDistance = maxDistance;
        class_1297 closestEntity = null;
        class_243 hitPos = null;

        for (class_1297 target : world.method_8333(entity, box, predicate)) {
            boolean sameRoot = target.method_5668().equals(shooterRoot);
            // The attachment's declared generic type is not visible here; the loop below needs a
            // list of per-hitbox vertex lists.
            @SuppressWarnings("unchecked")
            List<List<Vector3f>> hitboxes = (List<List<Vector3f>>) (List<?>) target.getAttached(HitboxAttachment.HITBOXES);

            if (hitboxes == null || hitboxes.isEmpty()) {
                // No custom hitboxes: test the (margin-expanded) bounding box instead.
                class_238 fallback = target.method_5829().method_1014((double) target.method_5871());
                Optional<class_243> boxHit = fallback.method_992(min, max);
                if (fallback.method_1006(min)) {
                    // The ray starts inside the box: treat it as a hit at distance 0.
                    if (shouldAccept(0.0, closestDistance, sameRoot)) {
                        closestEntity = target;
                        hitPos = boxHit.orElse(min);
                        closestDistance = 0.0;
                    }
                } else if (boxHit.isPresent()) {
                    class_243 intersect = boxHit.get();
                    double distance = min.method_1025(intersect);
                    if (shouldAccept(distance, closestDistance, sameRoot)) {
                        closestEntity = target;
                        hitPos = intersect;
                        closestDistance = distance;
                    }
                }
                continue;
            }

            for (List<Vector3f> hitbox : hitboxes) {
                if (hitbox == null) {
                    continue;
                }
                // Flat quads are turned into thin boxes so the cube triangulation applies.
                List<Vector3f> cubeVerts = hitbox.size() == 4 ? inflateQuadToBox(hitbox, 0.01f) : hitbox;
                if (cubeVerts.size() != 8) {
                    continue;
                }
                for (Triangle triangle : buildCubeTriangles(sortVertices(cubeVerts))) {
                    Optional<class_243> intersection = rayTriangleIntersect(min, rayDir, rayLength, triangle);
                    if (intersection.isEmpty()) {
                        continue;
                    }
                    class_243 intersect = intersection.get();
                    double distance = min.method_1025(intersect);
                    if (shouldAccept(distance, closestDistance, sameRoot)) {
                        closestEntity = target;
                        hitPos = intersect;
                        closestDistance = distance;
                    }
                }
            }
        }

        // setReturnValue already cancels the callback; no separate cancel() is needed.
        cir.setReturnValue(closestEntity != null ? new class_3966(closestEntity, hitPos) : null);
    }

    /**
     * Decides whether a candidate hit replaces the current best hit.
     *
     * <p>While the best distance is exactly {@code 0.0}, any candidate is accepted. Otherwise a
     * candidate must be strictly closer and must not share the caster's root vehicle. A NaN
     * {@code closestDistance} accepts nothing.
     *
     * @param distance        distance of the candidate hit from the ray start
     * @param closestDistance distance of the current best hit (or the initial bound)
     * @param sameRootVehicle whether the candidate shares the caster's root vehicle
     * @return {@code true} if the candidate should become the new best hit
     */
    @Unique
    private static boolean shouldAccept(double distance, double closestDistance, boolean sameRootVehicle) {
        if (closestDistance == 0.0) {
            return true;
        }
        return distance < closestDistance && !sameRootVehicle;
    }

    /**
     * Turns a flat quad into a thin 8-vertex box by offsetting each vertex by ±half the thickness
     * along the plane normal of its first three vertices.
     *
     * @param quad      exactly 4 vertices; only the first three define the plane
     * @param thickness total thickness of the resulting box
     * @return 8 vertices (each input vertex pushed to both sides of the plane), or an empty list
     *         when the first three vertices are collinear or non-finite and no normal exists
     * @throws IllegalArgumentException if {@code quad} does not contain exactly 4 vertices
     */
    @Unique
    private static List<Vector3f> inflateQuadToBox(List<Vector3f> quad, float thickness) {
        if (quad.size() != 4) {
            throw new IllegalArgumentException("Expected 4 vertices for quad");
        }
        Vector3f a = quad.get(0);
        Vector3f ab = new Vector3f(quad.get(1)).sub(a);
        Vector3f ac = new Vector3f(quad.get(2)).sub(a);
        Vector3f offset = ab.cross(ac);
        float lengthSquared = offset.lengthSquared();
        // A zero or NaN normal would normalize to NaN and poison every later intersection test;
        // an empty result makes the caller skip this hitbox instead.
        if (!(lengthSquared > 0.0f) || !Float.isFinite(lengthSquared)) {
            return List.of();
        }
        offset.normalize(thickness / 2.0f);
        ArrayList<Vector3f> inflated = new ArrayList<>(8);
        for (Vector3f v : quad) {
            inflated.add(new Vector3f(v).add(offset));
            inflated.add(new Vector3f(v).sub(offset));
        }
        return inflated;
    }

    /**
     * Orders the 8 corners of a box into the index layout {@link #buildCubeTriangles} expects.
     *
     * <p>Corners are sorted by three keys relative to the centroid {@code (cx, cy, cz)}: first
     * {@code y < cy}, then {@code z < cz}, then {@code x < cx}, with {@code false} before
     * {@code true}. When every corner falls on a distinct combination of sides, corner index =
     * {@code 4*(y<cy) + 2*(z<cz) + (x<cx)}.
     *
     * <p>NOTE(review): that labelling is only unique for roughly axis-aligned boxes. A box rotated
     * so that a corner lies on (or two corners share) a centroid side — e.g. 45 degrees about an
     * axis — gets a mismatched layout and therefore wrong triangles. Confirm whether rotated
     * hitboxes reach this code.
     *
     * @param verts exactly 8 corner vertices
     * @return an unmodifiable, sorted copy
     * @throws IllegalArgumentException if {@code verts} does not contain exactly 8 vertices
     */
    @Unique
    private static List<Vector3f> sortVertices(List<Vector3f> verts) {
        if (verts.size() != 8) {
            throw new IllegalArgumentException("Expected exactly 8 vertices");
        }
        Vector3f center = new Vector3f();
        for (Vector3f vertex : verts) {
            center.add(vertex);
        }
        center.div(8.0f);
        return verts.stream()
                .sorted(Comparator.comparing((Vector3f v) -> v.y < center.y)
                        .thenComparing(v -> v.z < center.z)
                        .thenComparing(v -> v.x < center.x))
                .toList();
    }

    /**
     * Splits the 6 faces of a box, given in {@link #sortVertices} order, into 12 triangles (two
     * per face).
     *
     * <p>Each pair of entries below covers one face: +z {0,1,4,5}, -z {2,3,6,7}, +y {0,1,2,3},
     * -y {4,5,6,7}, +x {0,2,4,6}, -x {1,3,5,7}. Winding is not consistent, which is fine because
     * the intersection test accepts hits from either side.
     *
     * @param verts 8 vertices in {@link #sortVertices} order
     * @return the 12 triangles
     */
    @Unique
    private static List<Triangle> buildCubeTriangles(List<Vector3f> verts) {
        int[][] triIndices = {
                {0, 1, 4}, {1, 5, 4},
                {2, 6, 3}, {3, 6, 7},
                {0, 2, 1}, {1, 2, 3},
                {4, 5, 6}, {5, 7, 6},
                {0, 4, 2}, {2, 4, 6},
                {1, 3, 5}, {3, 7, 5}
        };
        ArrayList<Triangle> triangles = new ArrayList<>(triIndices.length);
        for (int[] indices : triIndices) {
            triangles.add(new Triangle(verts.get(indices[0]), verts.get(indices[1]), verts.get(indices[2])));
        }
        return triangles;
    }

    /**
     * Möller–Trumbore ray/triangle intersection.
     *
     * <p>Double-sided: hits are accepted from either side of the triangle. A small barycentric
     * slack keeps rays that pass through shared triangle edges from slipping through gaps caused
     * by float rounding.
     *
     * @param rayOrigin ray start point
     * @param rayDir    ray direction; expected to be normalized so {@code t} is a distance
     * @param rayLength maximum accepted {@code t}
     * @param triangle  triangle to test
     * @return the hit point, or empty when there is no hit in {@code [0.01, rayLength]} or the
     *         inputs produce NaN
     */
    @Unique
    private static Optional<class_243> rayTriangleIntersect(class_243 rayOrigin, class_243 rayDir, double rayLength, Triangle triangle) {
        // Parallel-ray guard. It must be tiny: det scales with triangle area, so an absolute
        // cutoff of 0.01 rejected every face of about 0.1 block or smaller.
        final float detEpsilon = 1.0e-7f;
        // Barycentric slack on each edge.
        final float edgeSlack = 0.01f;
        final float maxBarycentric = 1.01f;
        // Hits closer than this to the ray start are ignored.
        final float minT = 0.01f;

        Vector3f v0 = triangle.v0();
        Vector3f edge1 = new Vector3f(triangle.v1()).sub(v0);
        Vector3f edge2 = new Vector3f(triangle.v2()).sub(v0);
        Vector3f dir = new Vector3f((float) rayDir.field_1352, (float) rayDir.field_1351, (float) rayDir.field_1350);

        Vector3f pvec = dir.cross(edge2, new Vector3f());
        float det = edge1.dot(pvec);
        // All checks below are written in negated form so NaN inputs are rejected, not accepted.
        if (!(Math.abs(det) > detEpsilon)) {
            return Optional.empty();
        }
        float invDet = 1.0f / det;

        // Subtract in double before narrowing: large world coordinates lose most of their
        // fractional precision if cast to float first.
        Vector3f tvec = new Vector3f(
                (float) (rayOrigin.field_1352 - v0.x),
                (float) (rayOrigin.field_1351 - v0.y),
                (float) (rayOrigin.field_1350 - v0.z));
        float u = tvec.dot(pvec) * invDet;
        if (!(u >= -edgeSlack && u <= maxBarycentric)) {
            return Optional.empty();
        }

        Vector3f qvec = tvec.cross(edge1, new Vector3f());
        float v = dir.dot(qvec) * invDet;
        if (!(v >= -edgeSlack && u + v <= maxBarycentric)) {
            return Optional.empty();
        }

        float t = edge2.dot(qvec) * invDet;
        if (!(t >= minT && (double) t <= rayLength)) {
            return Optional.empty();
        }
        return Optional.of(rayOrigin.method_1019(rayDir.method_1021((double) t)));
    }

    /**
     * Debug helper: spawns a {@code class_2398.field_11207} particle at each vertex via
     * {@code world.method_8406}. It is not called anywhere in this class.
     *
     * @param world world to spawn the particles in
     * @param verts positions to mark
     */
    @Unique
    private static void showHitboxVertices(class_1937 world, List<Vector3f> verts) {
        for (Vector3f vertex : verts) {
            world.method_8406((class_2394) class_2398.field_11207, (double) vertex.x, (double) vertex.y, (double) vertex.z, 0.0, 0.0, 0.0);
        }
    }
}

