import { createNextState } from '@reduxjs/toolkit';
import { get, isFunction, unset, updateWith } from 'lodash';

// Symbol keys for node/tree bookkeeping. Symbol-keyed properties are skipped
// by `Object.entries`, so they are never mistaken for children when walking
// the string-keyed child properties of a node (see `dfs`).
// Each symbol carries a description so it is identifiable in devtools and
// console output; every `Symbol(...)` call still yields a unique key.
export const nodeIdKey = Symbol('nodeId');
export const nodePathKey = Symbol('nodePath');
export const treeKey = Symbol('tree');
export const nodeValueKey = Symbol('nodeValue');

/**
 * Create a fresh tree node whose payload is stored under `nodeValueKey`.
 * The new node has no string-keyed properties, i.e. no children.
 *
 * @param {*} [value=null] - Payload to store on the node.
 * @returns {object} A newly allocated node object.
 */
export const makeNode = (value = null) => {
  const node = {};
  node[nodeValueKey] = value;
  return node;
};

/**
 * Insert or update the value stored at `nodePath` in the tree kept under
 * `treeKey` of the state owned by `setState`.
 *
 * Any missing nodes along the path are created with `makeNode()`. When the
 * resolved value is identical (`Object.is`) to the current one, the previous
 * state object is returned as-is so no new state is produced.
 *
 * @param {Function} setState - Setter invoked with an updater `(prev) => next`
 *   (presumably a React `useState` setter; confirm with callers).
 * @param {string|Array} nodePath - lodash path (as accepted by `get`/`updateWith`).
 * @param {*|Function} value - New value, or a function receiving the current
 *   value (undefined when the node does not exist) and returning the new one.
 */
export const upsertNode = (setState, nodePath, value) => {
  setState((prev) => {
    const currentValue = get(prev[treeKey], nodePath)?.[nodeValueKey];
    const nextValue = isFunction(value) ? value(currentValue) : value;

    // Same value: hand back the identical state reference so consumers see no change.
    if (Object.is(currentValue, nextValue)) {
      return prev;
    }

    // Used both for intermediate path segments and for the target node itself.
    const ensureNode = (existing) => existing ?? makeNode();
    const writeValue = (node) => Object.assign(ensureNode(node), { [nodeValueKey]: nextValue });

    return createNextState(prev, (draft) => {
      updateWith(draft[treeKey], nodePath, writeValue, ensureNode);
    });
  });
};

/**
 * Delete the node at `nodePath` from the tree kept under `treeKey`. Because
 * children are stored as properties of their parent node, the node's whole
 * subtree goes with it.
 *
 * @param {Function} setState - Setter invoked with an updater `(prev) => next`.
 * @param {string|Array} nodePath - lodash path (as accepted by `unset`).
 */
export const removeNode = (setState, nodePath) => {
  // The recipe body is braced on purpose: `unset` returns a boolean, and an
  // immer recipe must not return anything other than undefined or the draft.
  const deleteAtPath = (draft) => {
    unset(draft[treeKey], nodePath);
  };

  setState((prev) => createNextState(prev, deleteAtPath));
};

/**
 * Depth-first, pre-order traversal of a node and its descendants.
 *
 * `callback` receives `{ node, nodeId }` for the starting node first, then,
 * recursively, for each child in `Object.entries` order. Only string-keyed
 * properties count as children; symbol-keyed fields are skipped because
 * `Object.entries` ignores symbol keys.
 *
 * @param {object} node - Node to start from.
 * @param {*} nodeId - Id reported to `callback` for the starting node.
 * @param {Function} callback - Visitor called with `{ node, nodeId }`.
 */
export const dfs = (node, nodeId, callback) => {
  callback({ node, nodeId });
  for (const [childId, child] of Object.entries(node)) {
    dfs(child, childId, callback);
  }
};

/**
 * Flatten a subtree into a frozen `{ [nodeId]: value }` lookup object.
 *
 * Nodes are visited pre-order via `dfs`. When an id occurs more than once, a
 * warning is logged for each repeat and the value of the node visited last wins.
 *
 * @param {object} subtree - Root node of the subtree.
 * @param {*} subtreeRootNodeId - Id recorded for the root node.
 * @returns {Readonly<object>} Frozen object mapping node id to node value.
 */
export const flattenSubtree = (subtree, subtreeRootNodeId) => {
  // A Map serves as both the duplicate detector and the ordered entry list.
  const valuesById = new Map();

  dfs(subtree, subtreeRootNodeId, ({ node, nodeId }) => {
    if (valuesById.has(nodeId)) {
      console.warn(`Duplicated nodeId ${nodeId}`);
    }
    valuesById.set(nodeId, node[nodeValueKey]);
  });

  return Object.freeze(Object.fromEntries(valuesById));
};

// Two-level memo: subtree object (held weakly) -> root node id -> frozen flat map.
// Keying on the subtree alone is wrong: the same subtree flattened under a
// different root id would get back a result whose root entry uses the old id.
const flatNodesCache = new WeakMap();

/**
 * Memoized `flattenSubtree`.
 *
 * Results are cached per `(subtree, subtreeRootNodeId)` pair. Subtrees are held
 * weakly, so cached results are released together with the subtree object.
 * The cache assumes subtrees are treated as immutable (as produced by
 * `createNextState` in this file); mutating a subtree in place would leave a
 * stale cached result.
 *
 * @param {object} subtree - Root node of the subtree (must be an object: it is a WeakMap key).
 * @param {*} subtreeRootNodeId - Id recorded for the root node.
 * @returns {Readonly<object>} Frozen object mapping node id to node value.
 */
export const getFlatSubtree = (subtree, subtreeRootNodeId) => {
  const byRootId = flatNodesCache.get(subtree);
  if (byRootId?.has(subtreeRootNodeId)) {
    return byRootId.get(subtreeRootNodeId);
  }

  const flatNodes = flattenSubtree(subtree, subtreeRootNodeId);
  // Insert only after flattening succeeded, so a throwing flatten caches nothing.
  if (byRootId === undefined) {
    flatNodesCache.set(subtree, new Map([[subtreeRootNodeId, flatNodes]]));
  } else {
    byRootId.set(subtreeRootNodeId, flatNodes);
  }
  return flatNodes;
};
