import { useMemo } from "react";
import { Node, Edge } from "reactflow";

import { flextree } from "d3-flextree";
import { HierarchyNode, HierarchyPointNode, stratify } from "d3-hierarchy";

import { ExpandCollapseNode } from "~/components/explorer/explorerTypes";

/**
 * Options for `useExpandCollapse`. Every field is optional; the hook supplies
 * its own defaults when a field is omitted.
 */
export type UseExpandCollapseOptions = Partial<{
  /**
   * When true (hook default), run the flextree layout and assign positions.
   * When false, nodes keep whatever `position` they already carry.
   */
  layoutNodes: boolean;
  /** Base width used when sizing layout cells (hook default: 220). */
  treeWidth: number;
  /** Base height used when sizing layout cells (hook default: 100). */
  treeHeight: number;
}>;

/**
 * Type guard: reports whether a hierarchy node has been positioned by a
 * layout pass, i.e. whether both its `x` and `y` coordinates are numbers.
 *
 * @param pointNode - A hierarchy node that may or may not have been laid out.
 * @returns `true` when both `x` and `y` are of type `number`.
 */
function isHierarchyPointNode(
  pointNode:
    | HierarchyNode<ExpandCollapseNode>
    | HierarchyPointNode<ExpandCollapseNode>,
): pointNode is HierarchyPointNode<ExpandCollapseNode> {
  const { x, y } = pointNode as HierarchyPointNode<ExpandCollapseNode>;
  const isNumeric = (value: unknown): boolean => typeof value === "number";
  return isNumeric(x) && isNumeric(y);
}

/**
 * Extra vertical space reserved per entry of a node's `activities` list when
 * sizing that node for the flextree layout.
 * NOTE(review): the origin of `203` is not documented here — presumably the
 * rendered height of an activity row; confirm against the node component.
 */
const ACTIVITY_ROW_HEIGHT = 203 / 2;

/**
 * Derives the visible, optionally laid-out tree from a React Flow graph,
 * honouring each node's `data.expanded` flag.
 *
 * A node's parent is the `source` of the first edge whose `target` is that
 * node. The subtree below any node whose `data.expanded` is falsy is hidden.
 * When `layoutNodes` is true, the visible tree is laid out with `d3-flextree`.
 * The layout's x/y are swapped so the tree grows horizontally.
 *
 * Side effect (kept for backward compatibility): for every input node this
 * writes `data.expandable` (whether it has children) and, when
 * `data.activities` is set, `data.size` onto the caller's node `data` object.
 * Returned nodes receive a shallow copy of that `data`.
 *
 * @param nodes - All graph nodes. They are expected to form a single-rooted
 *   tree via `edges`; d3 `stratify` is documented to throw otherwise. An
 *   empty array yields an empty result.
 * @param edges - All graph edges; also used to derive parent/child links.
 * @param options - See `UseExpandCollapseOptions`.
 * @returns The visible nodes (with positions) and the edges whose source and
 *   target are both visible.
 */
function useExpandCollapse(
  nodes: Node[],
  edges: Edge[],
  {
    layoutNodes = true,
    treeWidth = 220,
    treeHeight = 100,
  }: UseExpandCollapseOptions = {},
): { nodes: Node[]; edges: Edge[] } {
  return useMemo(() => {
    // `stratify` throws ("no root") on empty input; an empty graph simply has
    // nothing to show.
    if (nodes.length === 0) {
      return { nodes: [], edges: [] };
    }

    // Index each target's parent once instead of scanning `edges` per node
    // (O(N + E) instead of O(N * E)). Only the first edge per target is kept,
    // matching the semantics of the former `edges.find(...)` lookup.
    const parentByTarget = new Map<string, string>();
    for (const edge of edges) {
      if (!parentByTarget.has(edge.target)) {
        parentByTarget.set(edge.target, edge.source);
      }
    }

    const hierarchy = stratify<ExpandCollapseNode>()
      .id((d) => d.id)
      .parentId((d: Node) => parentByTarget.get(d.id))(nodes);

    // `descendants()` is materialised before the loop, so nodes inside a
    // subtree that gets collapsed below are still visited and annotated.
    hierarchy.descendants().forEach((d) => {
      d.data.data.expandable = !!d.children?.length;

      if (d.data.data.activities) {
        d.data.data.size = [
          treeWidth * 2,
          d.data.data.activities.length * ACTIVITY_ROW_HEIGHT,
        ];
      }

      // Collapse: detach children so neither layout nor traversal sees them.
      d.children = d.data.data.expanded ? d.children : undefined;
    });

    const layout = flextree<ExpandCollapseNode>({})
      .nodeSize((node) =>
        node.data?.data.size
          ? node.data.data.size
          : [treeWidth + 20, treeHeight],
      )
      .spacing((node) => (node.data.data.expandable ? 1 : 30));

    const root = layoutNodes ? layout(hierarchy) : hierarchy;

    // Traverse the visible tree once; reuse it for the node list and for
    // O(1) endpoint checks when filtering edges.
    const visible = root.descendants();
    const visibleIds = new Set(visible.map((d) => d.id));

    return {
      nodes: visible.map((d) => ({
        ...d.data,
        // This bit is super important! We *mutated* the object in the `forEach`
        // above so the reference is the same. React needs to see a new reference
        // to trigger a re-render of the node.
        data: { ...d.data.data },
        type: d.data.type ?? "curriculum",
        position: isHierarchyPointNode(d)
          ? {
              // X and Y are flipped here to rotate the tree horizontally
              x: d.y,
              y: getYPosition(d),
            }
          : d.data.position,
      })),
      edges: edges.filter(
        (edge) => visibleIds.has(edge.source) && visibleIds.has(edge.target),
      ),
    };
  }, [nodes, edges, layoutNodes, treeWidth, treeHeight]);
}

/** Amount subtracted from the vertical coordinate of non-detail nodes. */
const NON_DETAIL_Y_OFFSET = 24;

/**
 * Computes the rendered vertical coordinate for a laid-out node.
 *
 * The tree is drawn rotated, so this reads the layout's `x`. Nodes whose
 * `data.detail` is falsy are offset by `-NON_DETAIL_Y_OFFSET`.
 *
 * @param node - A node that has been positioned by the layout.
 * @returns The y coordinate to use for rendering.
 */
const getYPosition = (node: HierarchyPointNode<ExpandCollapseNode>) =>
  node.data.data.detail ? node.x : node.x - NON_DETAIL_Y_OFFSET;

// The hook is exposed as this module's default export.
export default useExpandCollapse;
