/*
 * Decompiled with CFR 0.152.
 */
package net.countered.counteredsaccuratehitboxes.mixin.server;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import net.countered.counteredsaccuratehitboxes.util.HitboxAttachment;
import net.countered.counteredsaccuratehitboxes.util.Triangle;
import net.minecraft.core.particles.ParticleOptions;
import net.minecraft.core.particles.ParticleTypes;
import net.minecraft.world.entity.Entity;
import net.minecraft.world.entity.projectile.ProjectileUtil;
import net.minecraft.world.level.Level;
import net.minecraft.world.phys.AABB;
import net.minecraft.world.phys.EntityHitResult;
import net.minecraft.world.phys.Vec3;
import org.joml.Vector3f;
import org.joml.Vector3fc;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;

@Mixin(ProjectileUtil.class)
public abstract class ModProjectileUtilMixin {

    /**
     * Replaces the injected {@code ProjectileUtil} entity raycast with one that tests the ray against
     * per-entity custom hitboxes ({@link HitboxAttachment#HITBOXES}).
     *
     * <p>For entities without attached hitboxes it falls back to the entity's bounding box inflated by
     * its pick radius. The original method is always cancelled. The result is the closest accepted
     * {@link EntityHitResult}, or {@code null} when nothing was hit.
     *
     * @param entity      the entity casting the ray. Candidates sharing its root vehicle are only
     *                    accepted while the current best distance is exactly 0.
     * @param min         ray start point
     * @param max         ray end point
     * @param box         broad-phase search volume passed to {@code Level#getEntities}
     * @param predicate   candidate filter passed to {@code Level#getEntities}
     * @param maxDistance initial best distance. It is compared against {@code distanceToSqr} results,
     *                    i.e. it is treated as a squared distance.
     * @param cir         mixin callback used to return the result (and cancel the original)
     */
    @Inject(method = "raycast", at = @At("HEAD"), cancellable = true)
    private static void injectCustomHitboxRaycast(Entity entity, Vec3 min, Vec3 max, AABB box, Predicate<Entity> predicate, double maxDistance, CallbackInfoReturnable<EntityHitResult> cir) {
        Level world = entity.level();
        double closestDistance = maxDistance;
        Entity closestEntity = null;
        Vec3 hitPos = null;
        Vec3 rayDir = max.subtract(min).normalize();
        double rayLength = min.distanceTo(max);
        Entity ownRootVehicle = entity.getRootVehicle();

        for (Entity target : world.getEntities(entity, box, predicate)) {
            // Hoisted: this cannot change while we test the target's hitboxes.
            boolean sameRootVehicle = target.getRootVehicle().equals(ownRootVehicle);
            List<List<Vector3f>> hitboxes = target.getAttached(HitboxAttachment.HITBOXES);

            if (hitboxes == null || hitboxes.isEmpty()) {
                // No custom hitboxes: use the (pick-radius inflated) bounding box instead.
                AABB fallback = target.getBoundingBox().inflate(target.getPickRadius());
                Optional<Vec3> clip = fallback.clip(min, max);
                if (fallback.contains(min)) {
                    // The ray starts inside the box, so the hit counts as distance 0.
                    if (isPreferredHit(0.0, closestDistance, sameRootVehicle)) {
                        closestEntity = target;
                        hitPos = clip.orElse(min);
                        closestDistance = 0.0;
                    }
                } else if (clip.isPresent()) {
                    Vec3 intersect = clip.get();
                    double distance = min.distanceToSqr(intersect);
                    if (isPreferredHit(distance, closestDistance, sameRootVehicle)) {
                        closestEntity = target;
                        hitPos = intersect;
                        closestDistance = distance;
                    }
                }
                continue;
            }

            for (List<Vector3f> hitbox : hitboxes) {
                if (hitbox == null) {
                    continue;
                }
                // Flat quads are given a small thickness so they can be treated as 8-vertex boxes.
                List<Vector3f> cubeVerts = hitbox.size() == 4 ? inflateQuadToBox(hitbox, 0.01f) : hitbox;
                if (cubeVerts.size() != 8) {
                    continue;
                }
                for (Triangle triangle : buildCubeTriangles(sortVertices(cubeVerts))) {
                    Optional<Vec3> intersection = rayTriangleIntersect(min, rayDir, rayLength, triangle);
                    if (intersection.isEmpty()) {
                        continue;
                    }
                    Vec3 intersect = intersection.get();
                    double distance = min.distanceToSqr(intersect);
                    if (isPreferredHit(distance, closestDistance, sameRootVehicle)) {
                        closestEntity = target;
                        hitPos = intersect;
                        closestDistance = distance;
                    }
                }
            }
        }

        // setReturnValue also cancels the callback, so no separate cancel() call is needed.
        cir.setReturnValue(closestEntity != null ? new EntityHitResult(closestEntity, hitPos) : null);
    }

    /**
     * Decides whether a hit should replace the current best hit.
     *
     * <p>A hit wins if it is strictly closer than the current best, or if the current best distance is
     * exactly 0 (set when the ray starts inside a fallback bounding box). A candidate that shares the
     * caster's root vehicle is accepted only in that second case.
     *
     * @param distanceSq        squared distance from the ray origin to the candidate hit
     * @param closestDistanceSq squared distance of the current best hit
     * @param sameRootVehicle   whether the candidate shares the caster's root vehicle
     * @return {@code true} if the candidate should become the new best hit
     */
    @Unique
    private static boolean isPreferredHit(double distanceSq, double closestDistanceSq, boolean sameRootVehicle) {
        boolean bestIsZero = closestDistanceSq == 0.0;
        return (distanceSq < closestDistanceSq || bestIsZero) && (!sameRootVehicle || bestIsZero);
    }

    /**
     * Turns a flat quad into an 8-vertex box by offsetting each corner along the quad normal.
     *
     * <p>The normal is computed from the first three corners. Output order is
     * {@code [v0+n, v0-n, v1+n, v1-n, ...]}, where {@code n} is the normal scaled to
     * {@code thickness / 2}.
     *
     * @param quad      exactly four corner vertices
     * @param thickness total thickness of the resulting box
     * @return eight vertices, or an empty list when the first three corners are degenerate (collinear
     *         or non-finite). The caller skips hitboxes that do not have 8 vertices.
     * @throws IllegalArgumentException if {@code quad} does not contain exactly 4 vertices
     */
    @Unique
    private static List<Vector3f> inflateQuadToBox(List<Vector3f> quad, float thickness) {
        if (quad.size() != 4) {
            throw new IllegalArgumentException("Expected 4 vertices for quad");
        }
        Vector3f a = quad.get(0);
        Vector3f ab = new Vector3f(quad.get(1)).sub(a);
        Vector3f ac = new Vector3f(quad.get(2)).sub(a);
        Vector3f offset = ab.cross(ac);
        float lengthSq = offset.lengthSquared();
        if (!(lengthSq > 0.0f) || !Float.isFinite(lengthSq)) {
            // Without a well-defined normal, normalizing would produce NaN vertices.
            return List.of();
        }
        offset.normalize().mul(thickness / 2.0f);

        List<Vector3f> inflated = new ArrayList<>(8);
        for (Vector3f v : quad) {
            inflated.add(new Vector3f(v).add(offset));
            inflated.add(new Vector3f(v).sub(offset));
        }
        return inflated;
    }

    /**
     * Orders 8 box vertices by which side of the centroid they lie on.
     *
     * <p>Sort keys are y, then z, then x, with "not below the centroid" first. For an axis-aligned box
     * the resulting index is {@code 4*(y below) + 2*(z below) + (x below)}, which is the layout
     * {@link #buildCubeTriangles} expects. The stream sort is stable, so ties keep their input order.
     *
     * <p>NOTE(review): for boxes rotated far from the axes (around 45 degrees), several vertices can
     * fall into the same octant. The index layout then no longer matches a box; confirm the hitbox
     * source keeps rotations small, or switch to a topology-independent method.
     *
     * @param verts exactly eight vertices
     * @return an unmodifiable, sorted view of the vertices
     * @throws IllegalArgumentException if {@code verts} does not contain exactly 8 vertices
     */
    @Unique
    private static List<Vector3f> sortVertices(List<Vector3f> verts) {
        if (verts.size() != 8) {
            throw new IllegalArgumentException("Expected exactly 8 vertices");
        }
        Vector3f center = new Vector3f();
        for (Vector3f vertex : verts) {
            center.add(vertex);
        }
        center.div(8.0f);
        // Explicit parameter type: a chained comparator cannot infer it from the sorted() target type.
        Comparator<Vector3f> byOctant = Comparator
                .comparing((Vector3f v) -> v.y < center.y)
                .thenComparing(v -> v.z < center.z)
                .thenComparing(v -> v.x < center.x);
        return verts.stream().sorted(byOctant).toList();
    }

    /**
     * Builds the 12 triangles (two per face) of a box whose vertices are ordered as
     * {@link #sortVertices} produces.
     *
     * @param verts eight vertices in octant order
     * @return the 12 face triangles
     */
    @Unique
    private static List<Triangle> buildCubeTriangles(List<Vector3f> verts) {
        // Each pair of rows covers one face: a group of 4 indices sharing one index bit.
        int[][] triIndices = {
            {0, 1, 4}, {1, 5, 4},
            {2, 6, 3}, {3, 6, 7},
            {0, 2, 1}, {1, 2, 3},
            {4, 5, 6}, {5, 7, 6},
            {0, 4, 2}, {2, 4, 6},
            {1, 3, 5}, {3, 7, 5},
        };
        List<Triangle> triangles = new ArrayList<>(triIndices.length);
        for (int[] indices : triIndices) {
            triangles.add(new Triangle(verts.get(indices[0]), verts.get(indices[1]), verts.get(indices[2])));
        }
        return triangles;
    }

    /**
     * Intersects a ray with a triangle using the Möller–Trumbore algorithm.
     *
     * @param rayOrigin ray start point
     * @param rayDir    ray direction. Expected to be normalized so that {@code t} is a distance
     *                  comparable with {@code rayLength}.
     * @param rayLength maximum accepted distance along the ray
     * @param triangle  the triangle to test
     * @return the intersection point, or empty if the ray misses, is (nearly) parallel to the
     *         triangle, or hits outside {@code [0.01, rayLength]}
     */
    @Unique
    private static Optional<Vec3> rayTriangleIntersect(Vec3 rayOrigin, Vec3 rayDir, double rayLength, Triangle triangle) {
        // Reject rays whose |cos(angle to the face normal)| is at or below this value (grazing rays).
        // The threshold is relative to triangle size, so small hitboxes are not discarded.
        final float parallelEpsilon = 1.0e-4f;
        // Barycentric slack so that hits exactly on shared edges are not lost to round-off.
        final float barycentricMin = -0.01f;
        final float barycentricMax = 1.01f;
        // Hits closer than this to the ray origin are ignored.
        final float minHitDistance = 0.01f;

        Vector3f v0 = triangle.v0();
        Vector3f edge1 = new Vector3f(triangle.v1()).sub(v0);
        Vector3f edge2 = new Vector3f(triangle.v2()).sub(v0);
        Vector3f dir = new Vector3f((float) rayDir.x, (float) rayDir.y, (float) rayDir.z);

        Vector3f pvec = dir.cross(edge2, new Vector3f());
        float det = edge1.dot(pvec);
        // |edge1 x edge2| bounds |det| for a unit direction, so det / faceScale is |cos|.
        float faceScale = new Vector3f(edge1).cross(edge2).length();
        // Checks are written as !(in range) so NaN values from degenerate input are rejected too.
        if (!(Math.abs(det) > parallelEpsilon * faceScale)) {
            return Optional.empty();
        }
        float invDet = 1.0f / det;

        // Subtract in double before narrowing: casting large world coordinates to float first
        // would lose precision.
        Vector3f tvec = new Vector3f(
                (float) (rayOrigin.x - v0.x),
                (float) (rayOrigin.y - v0.y),
                (float) (rayOrigin.z - v0.z));
        float u = tvec.dot(pvec) * invDet;
        if (!(u >= barycentricMin && u <= barycentricMax)) {
            return Optional.empty();
        }

        Vector3f qvec = tvec.cross(edge1, new Vector3f());
        float v = dir.dot(qvec) * invDet;
        if (!(v >= barycentricMin && u + v <= barycentricMax)) {
            return Optional.empty();
        }

        float t = edge2.dot(qvec) * invDet;
        if (!(t >= minHitDistance && t <= rayLength)) {
            return Optional.empty();
        }
        return Optional.of(rayOrigin.add(rayDir.scale(t)));
    }

    /**
     * Spawns an end-rod particle at each vertex to visualize a hitbox.
     *
     * <p>Not called anywhere in this class; kept as a debugging aid.
     *
     * @param world level to spawn the particles in
     * @param verts vertices to mark
     */
    @Unique
    private static void showHitboxVertices(Level world, List<Vector3f> verts) {
        for (Vector3f v : verts) {
            world.addParticle(ParticleTypes.END_ROD, v.x, v.y, v.z, 0.0, 0.0, 0.0);
        }
    }
}

