import React, {
  createContext,
  useCallback,
  useContext,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import {
  FileImageType,
  mapAssetsToImagesFiles,
  mapInferencesToImagesFiles,
} from "domains/file-manager/interfaces";
import { useTeamContext } from "domains/teams/contexts/TeamProvider";
import { useUser } from "domains/user/hooks/useUser";
import {
  GetModelsInferencesByModelIdAndInferenceIdApiResponse,
  useGetAssetsQuery,
  useLazyGetModelsInferencesByModelIdAndInferenceIdQuery,
} from "infra/api/generated/api";
import _ from "lodash";

import { skipToken } from "@reduxjs/toolkit/dist/query";

/** Inference object as returned by the "get model inference" endpoint. */
type SessionInference =
  GetModelsInferencesByModelIdAndInferenceIdApiResponse["inference"];

/**
 * Shape of the value published through `SessionContext`.
 *
 * Despite its name, this describes the context value, not the props accepted
 * by `SessionProvider` (which only takes `children`).
 */
interface SessionProviderProps {
  /** Image files of the session's inferences, without the locally deleted ones. */
  files: FileImageType[];
  /** Adds an inference to the session and starts tracking it. */
  addInference: (inference: SessionInference) => void;
  /** Resets the session. */
  clearSession: () => void;
  /** Hides the given files from `files`. */
  deleteImageFiles: (imageFiles: FileImageType[]) => void;
  /** Shows again files previously hidden with `deleteImageFiles` (matched by id). */
  undeleteImageFiles: (imageFiles: FileImageType[]) => void;
}

/** Action that does nothing; the default for every context action. */
const noopSessionAction = (): void => {};

/**
 * Context carrying the session files and actions.
 *
 * Consumers rendered outside a `SessionProvider` receive this default value:
 * an empty file list and actions that do nothing.
 */
export const SessionContext = createContext<SessionProviderProps>({
  files: [],
  addInference: noopSessionAction,
  clearSession: noopSessionAction,
  deleteImageFiles: noopSessionAction,
  undeleteImageFiles: noopSessionAction,
});

/**
 * Decides whether `next` must replace `previous` as the published file list.
 *
 * It returns true when the lengths differ, or when a file of `next` has no
 * file with the same id in `previous`, or differs from it. Differences limited
 * to `meta.metadata.progressPercent` are ignored (presumably so that progress
 * ticks alone do not re-render consumers — confirm this is intended).
 */
function haveFilesChanged(
  previous: FileImageType[],
  next: FileImageType[]
): boolean {
  if (next.length !== previous.length) {
    return true;
  }
  // Index the previous files by id; the first occurrence wins, like Array.find.
  const previousById = new Map<FileImageType["id"], FileImageType>();
  for (const file of previous) {
    if (!previousById.has(file.id)) {
      previousById.set(file.id, file);
    }
  }
  return next.some((nextFile) => {
    const previousFile = previousById.get(nextFile.id);
    return (
      !previousFile ||
      !_.isEqual(
        _.omit(previousFile, "meta.metadata.progressPercent"),
        _.omit(nextFile, "meta.metadata.progressPercent")
      )
    );
  });
}

/**
 * Tracks the inferences started during the current session and publishes,
 * through `SessionContext`, the image files that belong to them.
 *
 * - Every inference passed to `addInference` is re-fetched every 5 s until its
 *   status becomes "failed" or "succeeded".
 * - The selected team's assets created after the session start are fetched
 *   (100 per page) and polled every 5 s while some inference still lacks
 *   images that were neither received nor deleted.
 * - Images not received yet are represented by placeholder files.
 * - `deleteImageFiles` / `undeleteImageFiles` only hide and show files in this
 *   provider's state; no delete request is issued here.
 */
export function SessionProvider({
  children = <></>,
}: {
  children?: React.ReactNode;
}) {
  const { selectedTeam } = useTeamContext();
  const { nsfwFilteredTypes } = useUser();
  const [inferences, setInferences] = useState<
    GetModelsInferencesByModelIdAndInferenceIdApiResponse["inference"][]
  >([]);
  const [getInferenceTrigger] =
    useLazyGetModelsInferencesByModelIdAndInferenceIdQuery();
  // Timers currently polling individual inferences.
  const intervals = useRef<ReturnType<typeof setInterval>[]>([]);
  const [createdAfter, setCreatedAfter] = useState<string>(() =>
    new Date().toISOString()
  );
  const [deletedImageFiles, setDeletedImageFiles] = useState<FileImageType[]>(
    []
  );
  const [shouldRefetch, setShouldRefetch] = useState(true);
  const [files, setFiles] = useState<FileImageType[]>([]);

  /** Stops every inference poll. */
  const clearAllIntervals = useCallback(() => {
    intervals.current.forEach((interval) => clearInterval(interval));
    intervals.current = [];
  }, []);

  // Stop polling when the provider unmounts.
  useEffect(() => {
    return () => clearAllIntervals();
  }, [clearAllIntervals]);

  const { data } = useGetAssetsQuery(
    !inferences.length
      ? skipToken
      : {
          teamId: selectedTeam.id,
          createdAfter,
          pageSize: "100",
        },
    {
      pollingInterval: shouldRefetch ? 5000 : undefined,
    }
  );

  useEffect(() => {
    const assets = data?.assets ?? [];
    const deletedIds = new Set(
      deletedImageFiles.map((imageFile) => imageFile.id)
    );
    let newShouldRefetch = false;

    // For each inference (in session order): its received assets followed by
    // one placeholder per image that is still missing.
    const sessionFiles = inferences.flatMap((inference): FileImageType[] => {
      const inferenceAssets = assets.filter(
        (asset) => asset.metadata.inferenceId === inference.id
      );
      // Guard against a missing numSamples instead of asserting it is set:
      // undefined would make the count NaN.
      let nbMissingAssets = Math.max(
        (inference.parameters.numSamples ?? 0) - inferenceAssets.length,
        0
      );
      if (nbMissingAssets > 0) {
        // Some of the missing images may have been deleted locally.
        const nbDeletedAssetsFromInference = deletedImageFiles.filter(
          (imageFile) => imageFile.meta.metadata.inferenceId === inference.id
        ).length;
        // Clamp at 0: a deleted file that the assets query still returns is
        // counted both in inferenceAssets and in the deleted count.
        nbMissingAssets = Math.max(
          nbMissingAssets - nbDeletedAssetsFromInference,
          0
        );
        // Images still missing for an inference that did not fail: keep polling.
        if (nbMissingAssets > 0 && inference.status !== "failed") {
          newShouldRefetch = true;
        }
      }
      return [
        ...mapAssetsToImagesFiles(inferenceAssets, nsfwFilteredTypes),
        ...mapInferencesToImagesFiles([
          {
            ...inference,
            images: Array.from({ length: nbMissingAssets }, (_unused, index) => ({
              id: `${inference.id}-${index}-placeholder`,
              url: "",
              seed: "",
            })),
          },
        ]),
      ];
    });

    if (newShouldRefetch !== shouldRefetch) {
      setShouldRefetch(newShouldRefetch);
    }

    const newFiles = sessionFiles.filter((file) => !deletedIds.has(file.id));
    // Only publish a new array when something relevant changed, so consumers
    // are not re-rendered on every poll.
    if (haveFilesChanged(files, newFiles)) {
      setFiles(newFiles);
    }
  }, [
    deletedImageFiles,
    data,
    inferences,
    shouldRefetch,
    files,
    nsfwFilteredTypes,
  ]);

  /** Polls `inference` every 5 s until it has failed or succeeded. */
  const refreshInference = useCallback(
    (
      inference: GetModelsInferencesByModelIdAndInferenceIdApiResponse["inference"]
    ) => {
      const interval = setInterval(async () => {
        const { data: getInferenceData } = await getInferenceTrigger({
          modelId: inference.modelId,
          inferenceId: inference.id,
          teamId: selectedTeam.id,
        });
        if (!getInferenceData) return;

        const updatedInference = getInferenceData.inference;
        // Replace the polled inference without mutating the previous state.
        setInferences((currentInferences) => {
          const inferenceIndex = currentInferences.findIndex(
            (i) => i.id === inference.id
          );
          if (inferenceIndex === -1) {
            // No longer part of the session (e.g. after clearSession).
            return currentInferences;
          }
          const nextInferences = [...currentInferences];
          nextInferences[inferenceIndex] = { ...updatedInference };
          return nextInferences;
        });

        if (
          updatedInference.status === "failed" ||
          updatedInference.status === "succeeded"
        ) {
          clearInterval(interval);
          intervals.current = intervals.current.filter((i) => i !== interval);
        }
      }, 5000);
      intervals.current = [...intervals.current, interval];
    },
    [getInferenceTrigger, selectedTeam.id]
  );

  const addInference = useCallback(
    (
      inference: GetModelsInferencesByModelIdAndInferenceIdApiResponse["inference"]
    ) => {
      setInferences((currentInferences) => [inference, ...currentInferences]);
      refreshInference(inference);
    },
    [refreshInference]
  );

  /**
   * Starts a new session: only assets created from now on are considered, all
   * inferences are dropped, their polls are stopped and local deletions are forgotten.
   */
  const clearSession = useCallback(() => {
    clearAllIntervals();
    setCreatedAfter(new Date().toISOString());
    setInferences([]);
    setDeletedImageFiles([]);
  }, [clearAllIntervals]);

  const deleteImageFiles = useCallback((imageFiles: FileImageType[]) => {
    setDeletedImageFiles((currentDeletedImageFiles) => [
      ...currentDeletedImageFiles,
      ...imageFiles,
    ]);
  }, []);

  const undeleteImageFiles = useCallback((imageFiles: FileImageType[]) => {
    const idsToRestore = new Set(imageFiles.map((imageFile) => imageFile.id));
    setDeletedImageFiles((currentDeletedImageFiles) =>
      currentDeletedImageFiles.filter(
        (deletedImageFile) => !idsToRestore.has(deletedImageFile.id)
      )
    );
  }, []);

  // Memoized so consumers only re-render when `files` or an action changes.
  const contextValue = useMemo<SessionProviderProps>(
    () => ({
      files,
      addInference,
      clearSession,
      deleteImageFiles,
      undeleteImageFiles,
    }),
    [files, addInference, clearSession, deleteImageFiles, undeleteImageFiles]
  );

  return (
    <SessionContext.Provider value={contextValue}>
      {children}
    </SessionContext.Provider>
  );
}

/**
 * Returns the session files and actions from the nearest `SessionProvider`.
 *
 * Outside a provider this yields `SessionContext`'s default value.
 */
export function useSessionContext(): SessionProviderProps {
  const session = useContext(SessionContext);
  return session;
}
