226 lines
8.8 KiB
TypeScript

"use client"
import { useState, useMemo, useEffect, useRef } from "react"
import {
useReactTable,
getCoreRowModel,
flexRender,
type ColumnDef,
type RowSelectionState,
} from "@tanstack/react-table"
import {
Table,
TableHeader,
TableBody,
TableFooter,
TableHead,
TableRow,
TableCell,
} from "@/shared/components/ui/table"
import { Checkbox } from "@/shared/components/ui/checkbox"
import { DataViewProvider } from "./data-view-context"
import type { DataViewProps } from "./types"
import { DataViewPagination } from "./data-view-pagination"
import { Skeleton } from "@/shared/components/ui/skeleton"
/**
 * Controlled, server-paginated table view built on TanStack Table.
 *
 * - Pagination and sorting are "manual": the table never slices or sorts
 *   `data` itself; user intent is forwarded to the parent via `onChange`
 *   (`{ type: "sorting" }` / `{ type: "pagination" }` events).
 * - When `selection` is provided, a checkbox column is prepended and selected
 *   rows are remembered across page changes, keyed by `selection.rowKey`
 *   (defaults to `"id"`). Every selection change reports the full cross-page
 *   selection through `selection.onSelectionChange`.
 * - While `isLoading` is true, `pagination.pageSize` skeleton rows are shown.
 * - `slots` lets callers inject actions above the table plus extra header,
 *   body and footer content.
 */
export function DataTable<TData>({
  columns,
  data,
  pagination,
  sorting = [],
  onChange,
  isLoading = false,
  onRowClick,
  slots,
  selection,
}: DataViewProps<TData>) {
  const rowKeyStr = (selection?.rowKey as string) ?? "id"
  // Single source of truth for a row's selection key. getRowId, the restore
  // effect and the selection-sync handler all use it so they always agree.
  const getRowKey = (row: TData): string =>
    String((row as Record<string, unknown>)[rowKeyStr])

  // Persisted map of id → original row data across all pages
  const persistedMap = useRef<Map<string, TData>>(new Map())
  // Current-page selection state that TanStack Table controls
  const [rowSelection, setRowSelection] = useState<RowSelectionState>({})

  // When the page/data changes, rebuild the current page's selection state
  // from the persisted map so rows selected on an earlier visit stay checked.
  useEffect(() => {
    if (!selection) return
    const restored: RowSelectionState = {}
    for (const row of data) {
      const id = getRowKey(row)
      if (persistedMap.current.has(id)) restored[id] = true
    }
    setRowSelection(restored)
  }, [data]) // eslint-disable-line react-hooks/exhaustive-deps

  // Checkbox column prepended when selection is enabled. The header checkbox
  // only affects the rows of the current page.
  const selectionColumn: ColumnDef<TData, unknown> = useMemo(
    () => ({
      id: "__select__",
      header: ({ table }) => (
        <Checkbox
          checked={
            table.getIsAllPageRowsSelected()
              ? true
              : table.getIsSomePageRowsSelected()
                ? "indeterminate"
                : false
          }
          onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
          aria-label="Select all"
        />
      ),
      cell: ({ row }) => (
        <Checkbox
          checked={row.getIsSelected()}
          onCheckedChange={(value) => row.toggleSelected(!!value)}
          aria-label="Select row"
          // Keep checkbox clicks from also triggering onRowClick on the row.
          onClick={(e) => e.stopPropagation()}
        />
      ),
      size: 40,
      enableSorting: false,
    }),
    [],
  )

  const resolvedColumns = useMemo(
    () =>
      selection
        ? [selectionColumn, ...(columns as ColumnDef<TData, unknown>[])]
        : (columns as ColumnDef<TData, unknown>[]),
    [selection, columns, selectionColumn],
  )

  const table = useReactTable({
    data,
    columns: resolvedColumns,
    getCoreRowModel: getCoreRowModel(),
    manualPagination: true,
    manualSorting: true,
    pageCount: pagination.pageCount,
    enableRowSelection: !!selection,
    // Row ids must match the persisted-map keys for selection restore to work.
    getRowId: selection ? getRowKey : undefined,
    state: {
      sorting,
      pagination: {
        // TanStack pages are 0-based; the DataView API is 1-based.
        pageIndex: pagination.page - 1,
        pageSize: pagination.pageSize,
      },
      rowSelection,
    },
    onSortingChange: (updater) => {
      const next = typeof updater === "function" ? updater(sorting) : updater
      onChange?.({ type: "sorting", sorting: next })
    },
    onPaginationChange: (updater) => {
      const current = { pageIndex: pagination.page - 1, pageSize: pagination.pageSize }
      const next = typeof updater === "function" ? updater(current) : updater
      onChange?.({
        type: "pagination",
        pagination: {
          page: next.pageIndex + 1,
          pageSize: next.pageSize,
          pageCount: pagination.pageCount,
          total: pagination.total,
        },
      })
    },
    onRowSelectionChange: (updater) => {
      const next = typeof updater === "function" ? updater(rowSelection) : updater
      setRowSelection(next)
      if (selection) {
        // Sync the current page into the persisted map; rows on other pages
        // keep whatever state they had.
        for (const row of data) {
          const id = getRowKey(row)
          if (next[id]) {
            persistedMap.current.set(id, row)
          } else {
            persistedMap.current.delete(id)
          }
        }
        selection.onSelectionChange(Array.from(persistedMap.current.values()))
      }
    },
  })

  // Number of rendered leaf columns, including the selection column when it
  // is enabled. Skeleton rows and the empty-state cell must span exactly this
  // many columns to line up with the header.
  const columnCount = table.getVisibleLeafColumns().length

  return (
    <DataViewProvider
      pagination={pagination}
      sorting={sorting}
      onChange={onChange}
      isLoading={isLoading}
    >
      <div data-slot="data-view" className="flex flex-col gap-2">
        {slots?.actions && (
          <div data-slot="data-view-actions">{slots.actions}</div>
        )}
        <div className="rounded-md border overflow-auto">
          <Table className="w-full">
            <TableHeader>
              {table.getHeaderGroups().map((headerGroup) => (
                <TableRow key={headerGroup.id}>
                  {headerGroup.headers.map((header) => (
                    <TableHead key={header.id}>
                      {header.isPlaceholder
                        ? null
                        : flexRender(
                            header.column.columnDef.header,
                            header.getContext(),
                          )}
                    </TableHead>
                  ))}
                </TableRow>
              ))}
              {slots?.extraHeader}
            </TableHeader>
            <TableBody>
              {isLoading ? (
                Array.from({ length: pagination.pageSize }).map((_, i) => (
                  <TableRow key={`skeleton-${i}`}>
                    {Array.from({ length: columnCount }, (_, j) => (
                      <TableCell key={`skeleton-${i}-${j}`}>
                        <Skeleton className="h-10 w-full" />
                      </TableCell>
                    ))}
                  </TableRow>
                ))
              ) : table.getRowModel().rows.length ? (
                table.getRowModel().rows.map((row) => (
                  <TableRow
                    key={row.id}
                    // Omit the attribute entirely rather than rendering "false".
                    data-state={row.getIsSelected() ? "selected" : undefined}
                    className={onRowClick ? "cursor-pointer" : undefined}
                    onClick={() => onRowClick?.(row.original)}
                  >
                    {row.getVisibleCells().map((cell) => (
                      <TableCell key={cell.id}>
                        {flexRender(cell.column.columnDef.cell, cell.getContext())}
                      </TableCell>
                    ))}
                  </TableRow>
                ))
              ) : (
                <TableRow>
                  <TableCell
                    colSpan={columnCount}
                    className="h-24 text-center text-muted-foreground"
                  >
                    No results.
                  </TableCell>
                </TableRow>
              )}
              {slots?.extraBody}
            </TableBody>
            {slots?.footer && (
              <TableFooter>{slots.footer}</TableFooter>
            )}
          </Table>
        </div>
        <DataViewPagination table={table} />
      </div>
    </DataViewProvider>
  )
}