import type { Message, User } from "@procision-software/database-zod";
import {
  createContext,
  useCallback,
  useContext,
  useEffect,
  useMemo,
  useState,
  type PropsWithChildren,
} from "react";
import { trpc } from "~/utils/trpc";

/** Callback receiving the id of a freshly created entity; may be sync or async. */
type CreatedEntityCallback = (id: string) => Promise<unknown> | void;

/**
 * Props accepted by `MessageThreadContextProvider`.
 */
type MessageThreadContextProviderProps = {
  /** Id of the user recorded as the author of messages created through the context. */
  userId: string;
  /** Thread to load; when absent, a thread is created before the first message is posted. */
  messageThreadId?: string;
  /** Forwarded to the thread-creation mutation when a new thread is needed. */
  facilityId?: string;
  /** Forwarded to the thread-creation mutation when a new thread is needed. */
  practiceId?: string;
  /** Awaited with the new thread id right after a thread is created. */
  onCreateMessageThread?: CreatedEntityCallback;
  /** Awaited with the new message id right after a message is created. */
  onCreateMessage?: CreatedEntityCallback;
};

/**
 * Shape of the value published through `MessageThreadContext`.
 */
type MessageThreadContextProviderValue = {
  /** Messages of the current thread, each joined with its authoring user. */
  messages: Array<Message & { user: User }>;
  /** Error flag of the underlying thread query. */
  isError: boolean;
  /** Loading flag of the underlying thread query. */
  isLoading: boolean;
  /** Set by the provider while `createMessage` is running. */
  isCreatingMessage: boolean;
  /** Posts a message with the given body, creating the thread first if none exists yet. */
  createMessage: (body: string) => Promise<void>;
};

/**
 * Context carrying the state exposed by `MessageThreadContextProvider`.
 *
 * The default is deliberately `null`: `useMessageThreadContext` checks for `null`
 * to reject use outside of a provider. A non-null default (e.g. no-op stubs)
 * would make that guard unreachable and let misplaced consumers silently get an
 * empty thread and a `createMessage` that does nothing.
 */
const MessageThreadContext = createContext<MessageThreadContextProviderValue | null>(null);

/**
 * Loads a message thread and exposes its messages plus a `createMessage` action
 * to descendants via `MessageThreadContext`.
 *
 * - `messageThreadId` selects the thread to load; changes to it replace the
 *   internally tracked thread id.
 * - When no thread id is known, the first `createMessage` call creates a thread
 *   (with `facilityId` / `practiceId`), awaits `onCreateMessageThread(id)`, and
 *   posts the message to that new thread.
 * - Every created message is authored by `userId`; `onCreateMessage(id)` is
 *   awaited afterwards.
 *
 * `createMessage` rejects if any step fails; `isCreatingMessage` is reset either way.
 */
export default function MessageThreadContextProvider({
  facilityId,
  messageThreadId,
  practiceId,
  userId,
  children,
  onCreateMessageThread,
  onCreateMessage,
}: PropsWithChildren<MessageThreadContextProviderProps>) {
  // Single tRPC utils handle for all cache work (`trpc.useContext` is the
  // deprecated alias of `useUtils`, so a second handle is unnecessary).
  const trpcUtils = trpc.useUtils();

  const invalidateCaseThreads = () => {
    void trpcUtils.core.messageThread.caseThreads.invalidate();
  };

  const [internalMessageThreadId, setInternalMessageThreadId] = useState(messageThreadId);
  const [isCreatingMessage, setIsCreatingMessage] = useState(false);

  const { mutateAsync: rpcCreateMessageThread } = trpc.core.messageThread.create.useMutation();
  const { mutateAsync: rpcCreateMessage } = trpc.core.message.create.useMutation({
    onSuccess: invalidateCaseThreads,
  });

  // Follow the parent when it switches to a different thread.
  useEffect(() => {
    setInternalMessageThreadId(messageThreadId);
  }, [messageThreadId]);

  const {
    data: messageThread,
    isError,
    isLoading,
  } = trpc.core.messageThread.get.useQuery({
    id: internalMessageThreadId,
  });

  const createMessageThread = useCallback(async () => {
    const { id } = await rpcCreateMessageThread({ facilityId, practiceId });

    setInternalMessageThreadId(id);

    // Invalidate the query for the thread that was just created. The closure's
    // `internalMessageThreadId` still holds the pre-creation (stale) id here.
    await trpcUtils.core.messageThread.get.invalidate({ id });

    await onCreateMessageThread?.(id);

    return id;
  }, [facilityId, onCreateMessageThread, practiceId, rpcCreateMessageThread, trpcUtils]);

  const createMessage = useCallback(
    async (body: string) => {
      setIsCreatingMessage(true);
      try {
        const threadId = internalMessageThreadId ?? (await createMessageThread());

        const { id } = await rpcCreateMessage({
          messageThreadId: threadId,
          userId,
          body,
        });

        await trpcUtils.core.messageThread.get.invalidate({ id: threadId });

        await onCreateMessage?.(id);
      } finally {
        // Always clear the flag; otherwise a failed request would leave the
        // consumers in a permanent "sending" state. The error still propagates.
        setIsCreatingMessage(false);
      }
    },
    [createMessageThread, internalMessageThreadId, onCreateMessage, rpcCreateMessage, trpcUtils, userId]
  );

  const messages = messageThread?.messages;

  // Memoized so consumers only re-render when something they read actually changes.
  const value = useMemo<MessageThreadContextProviderValue>(
    () => ({
      messages: messages ?? [],
      isError,
      isLoading,
      createMessage,
      isCreatingMessage,
    }),
    [messages, isError, isLoading, createMessage, isCreatingMessage]
  );

  return <MessageThreadContext.Provider value={value}>{children}</MessageThreadContext.Provider>;
}

/**
 * Returns the value published by the nearest `MessageThreadContextProvider`.
 *
 * @throws Error if the context value read here is `null`.
 */
export function useMessageThreadContext() {
  const threadContext = useContext(MessageThreadContext);

  if (threadContext !== null) {
    return threadContext;
  }

  throw new Error("useMessageThreadContext must be used within a MessageThreadContextProvider");
}
