import {
  type ColumnDef,
  type SortingState,
  flexRender,
  getCoreRowModel,
  getFacetedRowModel,
  getFacetedUniqueValues,
  getFilteredRowModel,
  getPaginationRowModel,
  getSortedRowModel,
  useReactTable,
} from '@tanstack/react-table';
import clsx from 'clsx';
import React from 'react';
import { useState } from 'react';
import { twMerge } from 'tailwind-merge';

/* ---------------------------------- Component --------------------------------- */

/**
 * Bare `<table>` primitive wrapped in a `relative w-full overflow-auto` `<div>`.
 *
 * `className` is merged (via `twMerge`) onto the `<table>` itself, not the
 * wrapper, and the forwarded ref points at the `<table>` element.
 */
const TableRoot = React.forwardRef<HTMLTableElement, React.HTMLAttributes<HTMLTableElement>>(
  ({ className, ...props }, ref) => (
    <div className="relative w-full overflow-auto">
      <table
        ref={ref}
        className={twMerge(clsx('w-full caption-bottom overflow-hidden rounded-radius bg-background', className))}
        {...props}
      />
    </div>
  )
);
// Kept distinct from the data-table component's displayName ('Table') so the two
// show up as different components in React DevTools and error messages.
TableRoot.displayName = 'TableRoot';

/* ---------------------------------- Component --------------------------------- */

/** Plain `<thead>` primitive; all props (including `className`) pass through unchanged. */
const TableHeader = React.forwardRef<HTMLTableSectionElement, React.HTMLAttributes<HTMLTableSectionElement>>(
  function TableHeaderRender(headerProps, forwardedRef) {
    return <thead {...headerProps} ref={forwardedRef} />;
  }
);
TableHeader.displayName = 'TableHeader';

/* ---------------------------------- Component --------------------------------- */

/** Plain `<tbody>` primitive; all props (including `className`) pass through unchanged. */
const TableBody = React.forwardRef<HTMLTableSectionElement, React.HTMLAttributes<HTMLTableSectionElement>>(
  function TableBodyRender(bodyProps, forwardedRef) {
    return <tbody {...bodyProps} ref={forwardedRef} />;
  }
);
TableBody.displayName = 'TableBody';

/* ---------------------------------- Component --------------------------------- */

/**
 * `<tfoot>` primitive with a top border and medium font weight.
 * A caller-supplied `className` is merged over the defaults with `twMerge`.
 */
const TableFooter = React.forwardRef<HTMLTableSectionElement, React.HTMLAttributes<HTMLTableSectionElement>>(
  function TableFooterRender({ className, ...rest }, forwardedRef) {
    const footerClasses = twMerge(clsx('border-t border-stroke-subtle font-medium', className));
    return <tfoot ref={forwardedRef} className={footerClasses} {...rest} />;
  }
);
TableFooter.displayName = 'TableFooter';

/* ---------------------------------- Component --------------------------------- */

/**
 * `<tr>` primitive. Rows carrying `data-state="selected"` get the
 * `bg-container-high` background; the last row drops its bottom border.
 * A caller-supplied `className` is merged over the defaults with `twMerge`.
 */
const TableRow = React.forwardRef<HTMLTableRowElement, React.HTMLAttributes<HTMLTableRowElement>>(
  function TableRowRender({ className, ...rest }, forwardedRef) {
    const rowClasses = clsx(
      'border-stroke-subtle transition-colors last:border-b-0 data-[state=selected]:bg-container-high ',
      className
    );
    return <tr ref={forwardedRef} className={twMerge(rowClasses)} {...rest} />;
  }
);
TableRow.displayName = 'TableRow';

/* ---------------------------------- Component --------------------------------- */

/**
 * `<th>` header-cell primitive.
 *
 * @param rounded - when true, adds `first:rounded-bl-radius last:rounded-br-radius`
 *   so the first/last header cells in a row get rounded bottom corners. The prop is
 *   consumed here and not forwarded to the DOM.
 */
const TableHead = React.forwardRef<
  HTMLTableCellElement,
  React.ThHTMLAttributes<HTMLTableCellElement> & { rounded?: boolean }
>(function TableHeadRender({ className, rounded, ...rest }, forwardedRef) {
  const cornerClasses = rounded ? 'first:rounded-bl-radius last:rounded-br-radius' : undefined;
  const headClasses = clsx(
    'px-5 py-3 text-left align-middle text-sm font-normal text-secondary bg-container [&:has([role=checkbox])]:pr-0 [&>[role=checkbox]]:translate-y-[2px]',
    cornerClasses,
    className
  );

  return <th ref={forwardedRef} className={twMerge(headClasses)} {...rest} />;
});
TableHead.displayName = 'TableHead';

/* ---------------------------------- Component --------------------------------- */

/**
 * `<td>` body-cell primitive. Cells containing a `role=checkbox` element lose
 * their right padding and nudge a direct checkbox child down by 2px.
 * A caller-supplied `className` is merged over the defaults with `twMerge`.
 */
const TableCell = React.forwardRef<HTMLTableCellElement, React.TdHTMLAttributes<HTMLTableCellElement>>(
  function TableCellRender({ className, ...rest }, forwardedRef) {
    const cellClasses = clsx(
      'px-5 py-3 align-middle text-sm text-foreground [&:has([role=checkbox])]:pr-0 [&>[role=checkbox]]:translate-y-[2px]',
      className
    );
    return <td ref={forwardedRef} className={twMerge(cellClasses)} {...rest} />;
  }
);
TableCell.displayName = 'TableCell';

/* ---------------------------------- Component --------------------------------- */

/**
 * `<caption>` primitive with small, secondary-colored text.
 *
 * The caller's `className` is resolved with `twMerge`, as in TableFooter, TableRow,
 * TableHead and TableCell. A conflicting utility (e.g. `text-lg`) therefore replaces
 * the default `text-sm` instead of both classes being emitted and competing.
 */
const TableCaption = React.forwardRef<HTMLTableCaptionElement, React.HTMLAttributes<HTMLTableCaptionElement>>(
  ({ className, ...props }, ref) => (
    <caption ref={ref} className={twMerge(clsx('py-1 text-sm text-secondary', className))} {...props} />
  )
);
TableCaption.displayName = 'TableCaption';

/* ---------------------------------- Type --------------------------------- */

/**
 * Props for the `Table` data-table component.
 *
 * @typeParam TData - row object type
 * @typeParam TValue - cell value type used by the column definitions
 */
export interface TableProps<TData, TValue> {
  /** Column definitions handed to `useReactTable`. */
  columns: ColumnDef<TData, TValue>[];
  /** Rows handed to `useReactTable`. */
  data: TData[];
  /**
   * Draws a border around the whole table; when enabled, header cells are not rounded.
   * @default false
   */
  bordered?: boolean;
  /** Rendered inside a `<caption>` element when provided. */
  caption?: React.ReactNode;
  /** Rendered in a single footer cell spanning the table's columns when provided. */
  footer?: React.ReactNode;
  /** Extra classes applied to the outer wrapper `<div>`. */
  className?: string;
}

/* ---------------------------------- Component --------------------------------- */
/**
 * Data table built on TanStack Table with client-side sorting and row selection.
 *
 * Sorting and row-selection state are held locally in this component. No pagination
 * controls are rendered, so every row of `data` is displayed (after any filtering and
 * sorting).
 *
 * @param props - see `TableProps`.
 */
function Table<TData, TValue>({
  bordered = false,
  columns,
  data,
  footer,
  caption,
  className,
}: TableProps<TData, TValue>) {
  const [sorting, setSorting] = useState<SortingState>([]);
  // Row id -> selected flag (the same shape as TanStack's `RowSelectionState`).
  const [rowSelection, setRowSelection] = useState<Record<string, boolean>>({});

  const table = useReactTable({
    data,
    columns,
    getCoreRowModel: getCoreRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    getPaginationRowModel: getPaginationRowModel(),
    getSortedRowModel: getSortedRowModel(),
    getFacetedRowModel: getFacetedRowModel(),
    getFacetedUniqueValues: getFacetedUniqueValues(),
    onSortingChange: setSorting,
    onRowSelectionChange: setRowSelection,
    state: {
      sorting,
      rowSelection,
    },
  });

  // A pagination row model is registered above, but this component renders no
  // pagination controls. `table.getRowModel()` would therefore return only the first
  // page (TanStack's default page size is 10), and the remaining rows would be
  // unreachable. Render the filtered + sorted rows from before pagination instead.
  const rows = table.getPrePaginationRowModel().rows;
  // Count rendered leaf columns, not top-level `columns`, so full-width cells still
  // span the whole table when column groups are used.
  const fullWidthSpan = table.getVisibleLeafColumns().length;
  const hasHeader = columns.some((column) => !!column.header);

  return (
    <div
      className={twMerge(clsx('w-full rounded-radius', bordered && 'border border-stroke overflow-hidden', className))}
    >
      <TableRoot className="w-full">
        {/* Ternary (not `&&`) so a falsy caption such as 0 never leaks a text node into <table>. */}
        {caption ? <TableCaption>{caption}</TableCaption> : null}
        {hasHeader && (
          <TableHeader>
            {table.getHeaderGroups().map((headerGroup) => (
              <TableRow key={headerGroup.id}>
                {headerGroup.headers.map((header) => (
                  <TableHead key={header.id} colSpan={header.colSpan} rounded={!bordered}>
                    {header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())}
                  </TableHead>
                ))}
              </TableRow>
            ))}
          </TableHeader>
        )}
        <TableBody>
          {rows.length > 0 ? (
            rows.map((row) => {
              const isSelected = row.getIsSelected();
              return (
                <TableRow
                  key={row.id}
                  // Omit the attribute entirely for unselected rows instead of emitting "false".
                  data-state={isSelected ? 'selected' : undefined}
                  className={clsx('hover:bg-container', isSelected && 'bg-container-high')}
                >
                  {row.getVisibleCells().map((cell) => (
                    <TableCell key={cell.id}>{flexRender(cell.column.columnDef.cell, cell.getContext())}</TableCell>
                  ))}
                </TableRow>
              );
            })
          ) : (
            <TableRow>
              <TableCell colSpan={fullWidthSpan} className="text-center">
                No results.
              </TableCell>
            </TableRow>
          )}
        </TableBody>
        {footer ? (
          <TableFooter>
            <TableRow>
              <TableCell colSpan={fullWidthSpan}>{footer}</TableCell>
            </TableRow>
          </TableFooter>
        ) : null}
      </TableRoot>
    </div>
  );
}
Table.displayName = 'Table';

export default Table;

export { TableBody, TableCaption, TableCell, TableFooter, TableHead, TableHeader, TableRoot, TableRow };
