import type {
  BillingCase,
  BillingCharge,
  BillingClaim,
  BillingPayer,
  BillingTransaction,
  BillingTransactionAllocation,
} from "@prisma/client";
import {
  BillingCaseStatus,
  BillingClaimStatus,
  BillingPayerType,
  BillingTransactionStatus,
  type Prisma,
  type PrismaClient,
} from "@prisma/client";
import logger from "@procision-software/asimov/logger";
import { type AppAbility, type OrganizationId } from "@procision-software/auth";
import { DateTime } from "luxon";
import { z } from "zod";
import { searchPatients } from "~/models/Patient";
import { selfPayOverPayments } from "../report/SelfPayOverPayments";
import { X12837PaymentFormatError, processBillingClaim } from "../repositories/x12-837";
import type { BillingClaimsListInputSchema } from "../server/trpc/router/claims";
import {
  billingCaseIdToString,
  needToApplyPrePayment,
  toBillingCaseId,
  type BillingCaseId,
} from "./case";
import { getChargeMasterForId } from "./chargemaster";
import { isOfficeOnlyCharge } from "./charges";
import { applyPrePayment } from "./claim/applyprepayment";
import { billingPayerIdToString, payerName, toBillingPayerId, type BillingPayerId } from "./payers";
import { adjustments, payments, priorPayersAllocations, sumAllocations } from "./payment";

/** Compile-time-only brand marker; no runtime value is ever created for this symbol. */
declare const BillingClaimIdTag: unique symbol;

/**
 * Primary key of a billing claim.
 *
 * At runtime this is a plain string. The phantom `tag` member exists only in the type system,
 * so a claim id cannot be passed where another kind of string id is expected (or vice versa)
 * without an explicit conversion through {@link toBillingClaimId}.
 */
export type BillingClaimId = string & { readonly tag: typeof BillingClaimIdTag };

/**
 * Brands a raw string as a {@link BillingClaimId}. No validation is performed.
 */
export function toBillingClaimId(id: string): BillingClaimId {
  const branded = id as BillingClaimId;
  return branded;
}

/**
 * Drops the {@link BillingClaimId} brand, e.g. for use in a Prisma `where` clause.
 */
export function billingClaimIdToString(id: BillingClaimId): string {
  const raw: string = id.toString();
  return raw;
}

/**
 * Raised when a payer is not enabled for Waystar.
 *
 * @param message - optional detail appended to the fixed prefix. When omitted or empty, the
 *   message is just the prefix (previously it ended in ": undefined").
 */
export class PayerNotWaystarEnabled extends Error {
  constructor(message?: string) {
    super(message ? `Payer not waystar enabled: ${message}` : "Payer not waystar enabled");
    // Give logs and stack traces a recognizable name instead of the generic "Error".
    this.name = "PayerNotWaystarEnabled";
  }
}

/**
 * Sets a claim's status and runs the follow-up work that the new status implies.
 *
 * 1. Persists `status` on the claim.
 * 2. When `status` is Billed, submits the claim via `processBillingClaim`:
 *    - An `X12837PaymentFormatError` is logged and tolerated, and the claim stays Billed. This
 *      is the expected path for payment methods that do not go through Stedi/Waystar, and it
 *      keeps their flow consistent with the rest of the system.
 *    - Any other error is logged and the claim's previous status is restored. Without the
 *      restore, the claim would be stuck in Billed and never processed again. The error is then
 *      rethrown.
 *
 *    After a successful or tolerated submission, the billing case is marked Billed if it is
 *    not already.
 * 3. Regardless of `status`, when `needToApplyPrePayment` reports true for the case, finds or
 *    creates the case's self-pay claim and applies the pre-payment to it.
 *
 * @returns the claim row returned by the status update in step 1. Changes made afterwards
 *   (e.g. by `processBillingClaim`) are not included.
 * @throws Prisma's not-found error if no claim has `id`.
 * @throws any error from `processBillingClaim` other than `X12837PaymentFormatError`.
 */
export async function processClaim(
  prisma: PrismaClient,
  ability: AppAbility,
  id: BillingClaimId,
  status: BillingClaimStatus
) {
  const billingClaim = await prisma.billingClaim.findFirstOrThrow({
    where: {
      id: billingClaimIdToString(id),
    },
    include: {
      billingCase: {
        include: {
          case: true,
        },
      },
    },
  });

  const updatedBillingClaim = await prisma.billingClaim.update({
    where: {
      id: billingClaim.id,
    },
    data: {
      status,
    },
  });

  if (status === BillingClaimStatus.Billed) {
    try {
      await processBillingClaim(prisma, ability, billingClaim.id);
    } catch (error) {
      logger.error(error);

      // X12837PaymentFormatError is tolerated (see doc above). For anything else, roll the
      // status back so the claim is not left Billed without having been submitted.
      if (!(error instanceof X12837PaymentFormatError)) {
        await prisma.billingClaim.update({
          where: {
            id: billingClaim.id,
          },
          data: {
            status: billingClaim.status, // restore original status
          },
        });

        throw error;
      }
    }

    if (billingClaim.billingCase.status !== BillingCaseStatus.Billed) {
      await prisma.billingCase.update({
        where: { id: billingClaim.billingCaseId },
        data: { status: BillingCaseStatus.Billed },
      });
    }
  }

  if (await needToApplyPrePayment(billingClaim.billingCase.case)) {
    const selfPayClaim = await findOrCreateSelfPayClaim(
      prisma,
      toBillingCaseId(billingClaim.billingCaseId)
    );
    await applyPrePayment(prisma, ability, toBillingClaimId(selfPayClaim.id));
  }

  return updatedBillingClaim;
}

/**
 * Raised when no next payer is found.
 */
export class NoNextPayerError extends Error {
  constructor() {
    super("No next payer found");
    // Give logs and stack traces a recognizable name instead of the generic "Error".
    this.name = "NoNextPayerError";
  }
}

/**
 * Sums `billedAmount` over every charge on the claim's billing case.
 *
 * Unlike `claimBillingAmounts`, this applies no payer-type filtering and does not exclude
 * office-only charges. It returns 0 when the case has no charges. `ability` is accepted but
 * not used here.
 *
 * @throws Prisma's not-found error if no claim has `id`.
 */
export async function computeClaimBilledAmount(
  prisma: PrismaClient,
  ability: AppAbility,
  id: BillingClaimId
) {
  const { billingCase } = await prisma.billingClaim.findFirstOrThrow({
    where: { id: billingClaimIdToString(id) },
    include: { billingCase: { include: { billingCharges: true } } },
  });

  let total = 0;
  for (const { billedAmount } of billingCase.billingCharges) {
    total += billedAmount;
  }
  return total;
}

type BillingClaimListInput = z.infer<typeof BillingClaimsListInputSchema>;

/**
 * Parses a user-typed numeric search term (case number, claim number).
 *
 * It keeps `parseInt`'s leniency: leading whitespace and trailing non-digits are ignored, so
 * "123abc" becomes 123. It always parses in base 10, and it returns `undefined` instead of
 * `NaN`, so the caller can skip the filter rather than pass `NaN` to Prisma as an Int filter.
 */
function parseIntegerFilter(value: string | null | undefined): number | undefined {
  if (!value) return undefined;
  const parsed = Number.parseInt(value, 10);
  return Number.isNaN(parsed) ? undefined : parsed;
}

/**
 * Returns one page of billing claims matching `input.filter`, newest first, plus the total
 * number of matching claims.
 *
 * Filtering happens in three stages:
 * 1. Case filters (patient name, case number, date of service) select `case` ids.
 * 2. Billing-case filters (billing case id, billing case status), restricted to those case
 *    ids, select billing case ids.
 * 3. Claim filters (claim status, claim number, payer type, last submission), restricted to
 *    those billing case ids, select the claims. With `onlyBillableCases`, claims reported by
 *    `selfPayOverPayments` are also excluded.
 *
 * Each row is the claim plus:
 * - its `claimBillingAmounts`;
 * - the allocations recorded against it;
 * - display names for its payer and for the case's primary payer (lowest `sequenceNumber`);
 * - the case's `financialReference` and `surgeryDate`.
 *
 * @param organizationId - forwarded to `selfPayOverPayments` when `onlyBillableCases` is set.
 */
export async function claimList(
  prisma: PrismaClient,
  ability: AppAbility,
  input: BillingClaimListInput,
  organizationId: OrganizationId
) {
  const { page, perPage, filter } = input;
  const billingClaimWhereInput: Prisma.BillingClaimWhereInput = {};
  const billingCaseWhereInput: Prisma.BillingCaseWhereInput = {};
  const caseWhereInput: Prisma.CaseWhereInput = {};

  // Stage 1: case-level filters.
  if (filter.patientName) {
    // NOTE(review): only the first 100 patients returned by the search are considered.
    const patients = await searchPatients(
      prisma,
      filter.patientName,
      { page: 1, perPage: 100 },
      false
    );
    caseWhereInput.patientId = {
      in: patients.rows.map((patient) => patient.id),
    };
  }
  const caseNumber = parseIntegerFilter(filter.caseNumber);
  if (caseNumber !== undefined) {
    caseWhereInput.financialReference = caseNumber;
  }
  if (filter.dateOfService) {
    caseWhereInput.surgeryDate = {
      gte: filter.dateOfService.start,
      lte: filter.dateOfService.end,
    };
  }

  // Stage 2: billing-case-level filters.
  if (filter.billingCaseId) {
    billingCaseWhereInput.id = billingCaseIdToString(filter.billingCaseId);
  }
  if (filter.billingCaseStatus && filter.billingCaseStatus.length > 0) {
    billingCaseWhereInput.status = { in: filter.billingCaseStatus };
  }

  // Stage 3: claim-level filters. Empty selections mean "no filter", not "match nothing".
  if (filter.claimStatus && filter.claimStatus.length > 0) {
    billingClaimWhereInput.status = { in: filter.claimStatus };
  }
  const claimNumber = parseIntegerFilter(filter.claimNumber);
  if (claimNumber !== undefined) {
    billingClaimWhereInput.referenceNumber = claimNumber;
  }
  if (filter.billingPayerType && filter.billingPayerType.length > 0) {
    billingClaimWhereInput.billingPayer = { paymentType: { in: filter.billingPayerType } };
  }
  if (filter.lastSubmittedDate) {
    type When = "never" | "on" | "before" | "more-recent-then";
    const submittedDay = DateTime.fromJSDate(filter.lastSubmittedDate);
    const startOfDay = submittedDay.startOf("day").toJSDate();
    // endOf("day") is the day's last millisecond (23:59:59.999), so "on" must include it.
    const endOfDay = submittedDay.endOf("day").toJSDate();
    const conditions: Record<When, Prisma.BillingClaimWhereInput["lastSubmittedAt"]> = {
      never: {
        equals: null,
      },
      on: {
        gte: startOfDay,
        lte: endOfDay,
      },
      before: {
        lt: startOfDay,
      },
      "more-recent-then": {
        gt: endOfDay,
      },
    };
    billingClaimWhereInput.lastSubmittedAt = conditions[filter.lastSubmittedTerm ?? "on"];
  } else if (filter.lastSubmittedTerm === "never") {
    billingClaimWhereInput.lastSubmittedAt = {
      equals: null,
    };
  }

  // Resolve stages 1 and 2 into the billing case ids that constrain the claim query.
  const caseIds = (
    await prisma.case.findMany({
      where: caseWhereInput,
      select: {
        id: true,
      },
    })
  ).map((c) => c.id);
  const billingCaseIds = (
    await prisma.billingCase.findMany({
      where: {
        ...billingCaseWhereInput,
        caseId: {
          in: caseIds,
        },
      },
      select: {
        id: true,
      },
    })
  ).map((c) => c.id);
  billingClaimWhereInput.billingCaseId = {
    in: billingCaseIds,
  };

  if (filter.onlyBillableCases) {
    const HANDLE_UP_TO = 1000;
    const overPaid = await selfPayOverPayments(prisma, ability, organizationId, {
      page: 1,
      perPage: HANDLE_UP_TO,
    });
    // NOTE(review): when more than HANDLE_UP_TO over-paid claims exist, this exclusion is
    // skipped entirely (not truncated), so over-paid claims will appear in the results.
    if (overPaid.pagination.all <= HANDLE_UP_TO) {
      billingClaimWhereInput.id = {
        notIn: overPaid.rows.map((claim) => claim.id),
      };
    }
  }

  const claims = await prisma.billingClaim.findMany({
    where: billingClaimWhereInput,
    skip: (page - 1) * perPage,
    take: perPage,
    include: {
      billingPayer: {
        include: {
          waystarInsuranceProvider: true,
        },
      },
      billingCase: {
        include: {
          billingPayers: {
            // just primary
            skip: 0,
            take: 1,
            orderBy: {
              sequenceNumber: "asc",
            },
          },
          billingCharges: {
            include: {
              allocations: {
                include: {
                  billingTransaction: true,
                  billingClaim: {
                    include: {
                      billingPayer: true,
                    },
                  },
                },
              },
            },
          },
          case: true,
        },
      },
    },
    orderBy: {
      createdAt: "desc",
    },
  });
  const all = await prisma.billingClaim.count({
    where: billingClaimWhereInput,
  });
  return {
    rows: await Promise.all(
      claims.map(async (claim) => {
        const waystar = claim.billingPayer.waystarInsuranceProvider;
        const allocations = claim.billingCase.billingCharges.flatMap((c) => c.allocations);
        const claimAmounts = claimBillingAmounts(claim, claim.billingCase, allocations);
        const primaryPayer = claim.billingCase.billingPayers[0];
        return {
          ...claim,
          ...claimAmounts,
          allocations: allocations.filter((a) => a.billingClaimId === claim.id),
          sequenceNumber: claim.billingPayer.sequenceNumber,
          payer: {
            ...claim.billingPayer,
            name: await payerName(prisma, ability, claim.billingPayer),
          },
          ...(primaryPayer && {
            primaryPayer: {
              ...primaryPayer,
              name: await payerName(prisma, ability, primaryPayer.id),
            },
          }),
          claimFormat: waystar?.institutionalClaimsAvailable ? "Electronic" : "Paper",
          claimType: waystar ? "UB-04" : "Offline",
          caseReferenceNumber: claim.billingCase.case.financialReference,
          dateOfService: claim.billingCase.case.surgeryDate,
        };
      })
    ),
    pagination: {
      page,
      perPage,
      all,
    },
  };
}

/**
 * Whether a claim may be deleted.
 *
 * A claim can be deleted only when all of the following hold:
 * - its status is not Done;
 * - it has never been submitted (`lastSubmittedAt` is null or undefined);
 * - it has no allocations recorded against it.
 */
export const canDeleteClaim = (claim: {
  status: BillingClaimStatus;
  lastSubmittedAt: Date | null | undefined;
  allocations: Record<string, unknown>[];
}) => {
  const isDone = claim.status === BillingClaimStatus.Done;
  const wasSubmitted = Boolean(claim.lastSubmittedAt);
  return !isDone && !wasSubmitted && claim.allocations.length === 0;
};

/**
 * Whether the payer referenced by `claim.billingPayerId` is a self-pay payer.
 * `ability` is accepted but not used here.
 *
 * @throws Prisma's not-found error if no payer has that id.
 */
export async function isSelfPay(
  prisma: PrismaClient,
  ability: AppAbility,
  claim: { billingPayerId: string }
): Promise<boolean> {
  const { paymentType } = await prisma.billingPayer.findFirstOrThrow({
    where: { id: claim.billingPayerId },
  });
  return BillingPayerType.Self_Pay === paymentType;
}

/**
 * Creates a claim for `input.billingPayerId` on `input.billingCaseId`.
 *
 * The first claim created for a given payer on a case gets the case's `financialReference` as
 * its `referenceNumber`. Later claims for the same payer are created with `referenceNumber`
 * left undefined.
 *
 * @param prisma - client used for the payer lookup and the insert.
 * @param input.billingCaseId - billing case the claim belongs to.
 * @param input.billingPayerId - payer being billed; must belong to `billingCaseId`.
 * @param input.frequencyCode - frequency code stored on the claim.
 * @param input.status - initial status; defaults to `BillingClaimStatus.New`.
 * @throws Prisma's not-found error if the payer does not exist on that billing case.
 */
export async function createBillingClaim(
  prisma: PrismaClient,
  input: {
    billingCaseId: BillingCaseId;
    billingPayerId: BillingPayerId;
    frequencyCode: string;
    status?: BillingClaimStatus;
  }
) {
  const payer = await prisma.billingPayer.findFirstOrThrow({
    where: {
      id: billingPayerIdToString(input.billingPayerId),
      billingCaseId: billingCaseIdToString(input.billingCaseId),
    },
    include: {
      billingCase: { include: { billingClaims: true, case: true } },
    },
  });

  const { billingClaims, case: linkedCase } = payer.billingCase;
  const isFirstClaimForPayer = !billingClaims.some(
    ({ billingPayerId }) => billingPayerId === payer.id
  );

  return await prisma.billingClaim.create({
    data: {
      frequencyCode: input.frequencyCode,
      status: input.status ?? BillingClaimStatus.New,
      referenceNumber: isFirstClaimForPayer ? linkedCase.financialReference : undefined,
      billingCase: { connect: { id: input.billingCaseId } },
      billingPayer: { connect: { id: payer.id } },
    },
  });
}

/**
 * Returns the billing case's self-pay claim, creating one if none exists yet.
 *
 * A new claim is created against the case's self-pay payer with frequency code "1" and
 * status New.
 *
 * @throws Prisma's not-found error if the billing case does not exist.
 * @throws Error("Self pay payer not found") if a claim must be created but the case has no
 *   self-pay payer.
 */
export async function findOrCreateSelfPayClaim(
  prisma: PrismaClient,
  billingCaseId: BillingCaseId
): Promise<BillingClaim> {
  const { billingClaims, billingPayers } = await prisma.billingCase.findFirstOrThrow({
    where: { id: billingCaseId },
    include: {
      billingPayers: true,
      billingClaims: { include: { billingPayer: true } },
    },
  });

  const existing = billingClaims.find((claim) => isClaimOfType(claim, BillingPayerType.Self_Pay));
  if (existing) return existing;

  const selfPayPayer = billingPayers.find(
    ({ paymentType }) => paymentType === BillingPayerType.Self_Pay
  );
  if (!selfPayPayer) throw new Error("Self pay payer not found");

  return await createBillingClaim(prisma, {
    billingCaseId,
    billingPayerId: toBillingPayerId(selfPayPayer.id),
    frequencyCode: "1",
    status: BillingClaimStatus.New,
  });
}

/**
 * Whether the claim's payer has the given payment type.
 */
export function isClaimOfType(
  claim: Partial<BillingClaim> & { billingPayer: Pick<BillingPayer, "paymentType"> },
  type: BillingPayerType
) {
  const { paymentType } = claim.billingPayer;
  return type === paymentType;
}

/**
 * Loads a claim and attaches its payer's display name as `billingPayer.payerName`.
 *
 * The claim is returned with its payer, its transactions, and its billing case's charges. Each
 * charge includes its charge-master entry, and charges are ordered by `sequenceNumber`.
 *
 * @throws Prisma's not-found error if no claim has `id`.
 */
export async function getBillingClaim(
  prisma: PrismaClient,
  ability: AppAbility,
  id: BillingClaimId
) {
  const claim = await prisma.billingClaim.findFirstOrThrow({
    where: { id: billingClaimIdToString(id) },
    include: {
      billingPayer: true,
      transactions: true,
      billingCase: {
        include: {
          billingCharges: {
            include: { billingChargeMaster: true },
            orderBy: { sequenceNumber: "asc" },
          },
        },
      },
    },
  });

  const displayName = await payerName(prisma, ability, claim.billingPayerId);
  const billingPayer = { ...claim.billingPayer, payerName: displayName };
  return { ...claim, billingPayer };
}

/** A claim row together with the payer it bills. */
export type BillingClaimWithPayer = BillingClaim & { billingPayer: BillingPayer };
/** A billing case row together with its charges. */
export type BillingCaseWithCharges = BillingCase & { billingCharges: BillingCharge[] };
/**
 * An allocation row together with its transaction and its claim (including that claim's payer).
 * This is the allocation shape consumed by `claimBillingAmounts` and `claimBillingAmountsByCharge`.
 */
export type BillingTransactionAllocationWithClaims = BillingTransactionAllocation & {
  billingTransaction: BillingTransaction;
  billingClaim: BillingClaimWithPayer;
};

/** Runtime (zod) schema for {@link BillingAmounts}. */
export const BillingAmountsSchema = z.object({
  billedAmount: z.number(),
  priorPayersAmount: z.number(),
  expectedAmount: z.number(),
  adjustmentAmount: z.number(),
  paymentAmount: z.number(),
  outstandingAmount: z.number(),
});
/**
 * Billing amounts for a single charge, or summed over a claim's charges, seen from one payer.
 *
 * Only allocations on Complete transactions count toward any of these amounts.
 *
 * @prop billedAmount - the charge's billed amount. For a claim, this is the sum over the case's
 *   charges whose `payerTypes` include the payer's type, excluding office-only charges. It is
 *   not a total across every charge on the case.
 * @prop priorPayersAmount - payments and adjustments allocated by payers earlier in the payer
 *   sequence, as selected by `priorPayersAllocations`.
 * @prop expectedAmount - what this payer is responsible for: billedAmount - priorPayersAmount.
 * @prop adjustmentAmount - adjustments allocated against claims of this payer.
 * @prop paymentAmount - payments allocated against claims of this payer.
 * @prop outstandingAmount - what this payer still owes:
 *   expectedAmount - adjustmentAmount - paymentAmount.
 */
export type BillingAmounts = z.infer<typeof BillingAmountsSchema>;

/** A charge row augmented with its computed {@link BillingAmounts}. */
type BillingChargeWithAmounts = BillingCharge & BillingAmounts;

/**
 * Totals the per-charge billing amounts for a claim.
 *
 * Charges are computed with `claimBillingAmountsByCharge`. Office-only charges are excluded
 * from the totals and returned separately in `officeOnlyCharges`, in their original order.
 */
export function claimBillingAmounts(
  claim: BillingClaimWithPayer,
  billingCase: BillingCaseWithCharges,
  allocations: BillingTransactionAllocationWithClaims[]
): BillingAmounts & { officeOnlyCharges: BillingChargeWithAmounts[] } {
  const officeOnlyCharges: BillingChargeWithAmounts[] = [];
  const totals: BillingAmounts = {
    billedAmount: 0,
    priorPayersAmount: 0,
    expectedAmount: 0,
    adjustmentAmount: 0,
    paymentAmount: 0,
    outstandingAmount: 0,
  };

  for (const charge of claimBillingAmountsByCharge(claim, billingCase, allocations)) {
    if (isOfficeOnlyCharge(charge)) {
      officeOnlyCharges.push(charge);
      continue;
    }
    totals.billedAmount += charge.billedAmount;
    totals.priorPayersAmount += charge.priorPayersAmount;
    totals.expectedAmount += charge.expectedAmount;
    totals.adjustmentAmount += charge.adjustmentAmount;
    totals.paymentAmount += charge.paymentAmount;
    totals.outstandingAmount += charge.outstandingAmount;
  }

  return { ...totals, officeOnlyCharges };
}

/**
 * Computes {@link BillingAmounts} for each charge on the case that applies to the claim's payer.
 *
 * A charge applies when its `payerTypes` include the payer's payment type. Only allocations on
 * Complete transactions are counted:
 * - Prior-payer amounts come from `priorPayersAllocations`, using the payer's `sequenceNumber`.
 * - Adjustments and payments come from allocations whose claim belongs to this same payer.
 */
export function claimBillingAmountsByCharge(
  claim: BillingClaimWithPayer,
  billingCase: BillingCaseWithCharges,
  allocations: BillingTransactionAllocationWithClaims[]
): BillingChargeWithAmounts[] {
  const { paymentType, sequenceNumber } = claim.billingPayer;
  const applicableCharges = billingCase.billingCharges.filter((charge) =>
    charge.payerTypes.includes(paymentType)
  );

  return applicableCharges.map((charge) => {
    const settled = allocations.filter(
      ({ billingChargeId, billingTransaction }) =>
        billingChargeId === charge.id &&
        billingTransaction.status === BillingTransactionStatus.Complete
    );
    // Fresh array per call, mirroring the separate filters handed to adjustments()/payments().
    const fromThisPayer = () =>
      settled.filter((a) => a.billingClaim.billingPayerId === claim.billingPayerId);

    const priorPayersAmount = sumAllocations(priorPayersAllocations(settled, sequenceNumber));
    const adjustmentAmount = sumAllocations(adjustments({ allocations: fromThisPayer() }));
    const paymentAmount = sumAllocations(payments({ allocations: fromThisPayer() }));
    const expectedAmount = charge.billedAmount - priorPayersAmount;

    return {
      ...charge,
      billedAmount: charge.billedAmount,
      priorPayersAmount,
      expectedAmount,
      adjustmentAmount,
      paymentAmount,
      outstandingAmount: expectedAmount - adjustmentAmount - paymentAmount,
    } satisfies BillingChargeWithAmounts;
  });
}

/**
 * Returns `claim` extended with its billing totals, its own allocations, and per-charge details.
 *
 * It loads the claim's billing case with its charges, plus every allocation on a Complete
 * transaction for any claim of that case. From these it adds:
 * - the `claimBillingAmounts` fields (totals and `officeOnlyCharges`);
 * - `allocations`: only the allocations recorded against this claim;
 * - `charges`: each applicable charge with its amounts, its charge-master `name`, and its
 *   `code` (the CPT code, falling back to the HCPCS code).
 *
 * @throws Prisma's not-found error if the billing case does not exist.
 */
export async function augmentClaimWithBillingAmounts<TClaim extends BillingClaimWithPayer>(
  prisma: PrismaClient,
  claim: TClaim
) {
  const billingCase: BillingCaseWithCharges = await prisma.billingCase.findFirstOrThrow({
    where: { id: claim.billingCaseId },
    include: { billingCharges: true },
  });

  const allocations: BillingTransactionAllocationWithClaims[] =
    await prisma.billingTransactionAllocation.findMany({
      where: {
        billingClaim: { billingCaseId: billingCase.id },
        billingTransaction: { status: BillingTransactionStatus.Complete },
      },
      include: {
        billingTransaction: true,
        billingClaim: { include: { billingPayer: true } },
      },
    });

  const totals = claimBillingAmounts(claim, billingCase, allocations);
  const ownAllocations = allocations.filter((a) => a.billingClaimId === claim.id);

  const withChargeMaster = async (charge: BillingChargeWithAmounts) => {
    const master = await getChargeMasterForId(prisma, charge.billingChargeMasterId);
    return { ...charge, name: master.name, code: master.cptCode ?? master.hcpcsCode };
  };
  const charges = await Promise.all(
    claimBillingAmountsByCharge(claim, billingCase, allocations).map(withChargeMaster)
  );

  return { ...claim, ...totals, allocations: ownAllocations, charges };
}
