import React, {
    ForwardedRef,
    useCallback,
    useEffect,
    useImperativeHandle,
    useRef,
} from "react";
import ThreeDotsWave from "../ThreeDotsWave";
import { Box, Stack, SxProps, Theme } from "@mui/material";
import { useVirtualizer } from "@tanstack/react-virtual";
import { isDefined } from "@convin/utils/helper/common.helper";

/**
 * Props accepted by `VirtualList`.
 *
 * @typeParam T - type of a single item in `data`.
 * @typeParam U - extra props; everything not destructured by the list is
 *   forwarded unchanged to every rendered `Component`.
 */
type Props<T, U extends Record<string, unknown> = Record<string, unknown>> = {
    /** Items to render, one virtual row per item. Defaults to `[]`. */
    data?: Array<T>;
    /**
     * Row renderer, used for every non-loader row. Receives the row item
     * (`data`), its `index`, its neighbours (`prevIndexedData` is `null` for
     * the first row, `nextIndexedData` is `null` for the last), an
     * `onchanged` callback (the list currently supplies a no-op), and all
     * extra props `U`.
     * NOTE(review): the list also passes `isLoading` to this component, but
     * it is not declared in the props type below.
     */
    Component: React.FC<
        {
            prevIndexedData: T | null;
            nextIndexedData: T | null;
            data: T;
            index: number;
            onchanged: () => void;
        } & { [key in keyof U]: U[key] }
    >;
    /**
     * Initial size estimate (px) for every row, fed to the virtualizer's
     * `estimateSize`; rendered rows are then measured. Defaults to 1000.
     */
    rowSize?: number;
    /**
     * When true, one extra loader row is appended after the data and
     * `fetchNext` may be triggered when it scrolls into the virtual window.
     */
    hasNext?: boolean;
    /**
     * Invoked when the loader row enters the virtual window. It is not
     * re-invoked while `isFetching` is true, nor again until the list's
     * internal guard is reset (on a `data`/`isFetching` change with
     * `isFetching` falsy).
     */
    fetchNext?: () => unknown;
    /** While true, `fetchNext` is not invoked. */
    isFetching?: boolean;
    /** Styles for the scroll container, applied on top of the default `p: 3`. */
    sx?: SxProps<Theme>;
    /** Forwarded as-is to every rendered `Component`. */
    isLoading?: boolean;
    /** Passed to the virtualizer only when defined; otherwise its default applies. */
    overscan?: number;
    /**
     * Enables mandatory vertical CSS scroll snapping; each row snaps to its
     * start. Defaults to false.
     */
    enableSnap?: boolean;
    /**
     * Adds the `scroll-thumb-visible` class to the scroll container
     * (presumably a global style that keeps the thumb visible — confirm).
     * Defaults to false.
     */
    showScrollBar?: boolean;
} & U;

/**
 * Imperative handle exposed through the `ref` of `VirtualList`.
 */
export type VirtualRefType = {
    /** Scrolls the list to the row at `index` (non-smooth, `behavior: "auto"`). */
    scrollToIndex: (index: number) => void;
};

/**
 * Virtualized, optionally infinitely-scrolling list.
 *
 * Rows are absolutely positioned inside a container sized to the
 * virtualizer's total size. Every rendered row is measured through
 * `measureElement`, so `rowSize` is only the initial estimate. When `hasNext`
 * is true an extra loader row (index `data.length`) is appended. `fetchNext`
 * is invoked once when that row enters the virtual window, and is guarded
 * until the parent reports the fetch finished.
 *
 * @param props - see `Props`; extra props (`U`) and `isLoading` are forwarded
 *   to every rendered `Component`.
 * @param ref - receives a `VirtualRefType` handle exposing `scrollToIndex`.
 */
function VirtualListInner<
    T,
    U extends Record<string, unknown> = Record<string, unknown>
>(
    {
        data = [],
        Component,
        rowSize = 1000,
        hasNext,
        fetchNext,
        isFetching,
        isLoading,
        sx,
        overscan,
        enableSnap = false,
        showScrollBar = false,
        ...rest
    }: Props<T, U>,
    ref?: ForwardedRef<VirtualRefType | undefined>
) {
    // Hard cap on rows rendered per pass, kept from the original code.
    // NOTE(review): if visible rows + overscan exceed this, trailing rows
    // (including the loader row) are not rendered — confirm it is intended.
    const maxRenderedRows = 50;

    const parentRef = useRef<HTMLDivElement>(null);
    // Prevents fetchNext from firing repeatedly before the parent flips
    // isFetching or delivers new data.
    const fetchNextCalled = useRef(false);

    const rowVirtualizer = useVirtualizer({
        count: hasNext ? data.length + 1 : data.length,
        getScrollElement: () => parentRef.current,
        estimateSize: useCallback(() => rowSize, [rowSize]),
        ...(isDefined(overscan) ? { overscan } : {}),
    });
    const virtualItems = rowVirtualizer.getVirtualItems();

    useImperativeHandle(
        ref,
        () => ({
            scrollToIndex: (index: number) =>
                rowVirtualizer.scrollToIndex(index, { behavior: "auto" }),
        }),
        [rowVirtualizer]
    );

    useEffect(() => {
        const lastItem = virtualItems.at(-1);
        if (!lastItem || !fetchNext) {
            return;
        }

        // The loader row lives at index data.length; reaching it means the
        // user has scrolled past the loaded data.
        if (
            lastItem.index === data.length &&
            hasNext &&
            !isFetching &&
            !fetchNextCalled.current
        ) {
            fetchNext();
            fetchNextCalled.current = true;
        }
    }, [virtualItems, data.length, hasNext, isFetching, fetchNext]);

    useEffect(() => {
        // Re-arm the guard once fetching is not in progress.
        if (!isFetching) {
            fetchNextCalled.current = false;
        }
    }, [data, isFetching]);

    return (
        <Box
            // sx may be an object, array or function; the array form merges
            // all of them after the default padding.
            sx={[{ p: 3 }, ...(Array.isArray(sx) ? sx : [sx])]}
            ref={parentRef}
            className={`w-full h-full overflow-x-hidden overflow-y-auto ${
                showScrollBar ? "scroll-thumb-visible" : ""
            }`}
            style={{
                ...(enableSnap ? { scrollSnapType: "y mandatory" } : {}),
            }}
        >
            <Stack
                style={{
                    height: `${rowVirtualizer.getTotalSize()}px`,
                    width: "100%",
                    position: "relative",
                }}
                role="listbox"
                data-testid="virtual-list-container"
            >
                {virtualItems
                    .slice(0, maxRenderedRows)
                    .map(({ index, start, key }) => {
                        const isLoaderRow = index > data.length - 1;
                        const item = data[index];
                        return (
                            <div
                                key={key}
                                data-index={index}
                                ref={rowVirtualizer.measureElement}
                                style={{
                                    position: "absolute",
                                    top: 0,
                                    left: 0,
                                    width: "100%",
                                    // No explicit height: each row is sized by
                                    // its content and measured by the
                                    // virtualizer via measureElement.
                                    transform: `translateY(${start}px)`,
                                    ...(enableSnap
                                        ? { scrollSnapAlign: "start" }
                                        : {}),
                                }}
                            >
                                {isLoaderRow ? (
                                    hasNext ? (
                                        <ThreeDotsWave />
                                    ) : (
                                        <></>
                                    )
                                ) : (
                                    <Component
                                        {...{
                                            prevIndexedData: index
                                                ? data[index - 1]
                                                : null,
                                            nextIndexedData:
                                                data.at(index + 1) ?? null,
                                            data: item,
                                            index,
                                            onchanged: () => {},
                                            isLoading,
                                            ...(rest as unknown as U),
                                        }}
                                    />
                                )}
                            </div>
                        );
                    })}
            </Stack>
        </Box>
    );
}

/**
 * Wraps `React.forwardRef` so the resulting component keeps the generic
 * props (`P`) and ref (`T`) types of `render`.
 *
 * @param render - render function receiving props and the forwarded ref.
 * @returns the component produced by `React.forwardRef(render)`, typed as a
 *   function accepting `P` plus a `ref` of `T`.
 */
function fixedForwardRef<T, P>(
    render: (props: P, ref: React.Ref<T>) => React.ReactNode
): (props: P & React.RefAttributes<T>) => React.ReactNode {
    // forwardRef's declared types do not line up with this signature, so the
    // value is re-typed through `any`; there is no runtime effect.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const forwarded: any = React.forwardRef(render as any);
    return forwarded;
}

/** Generic virtualized list component with a forwarded `VirtualRefType` ref. */
const VirtualList = fixedForwardRef(VirtualListInner);

export { VirtualList as default };
