import { createAsyncThunk, createSlice } from "@reduxjs/toolkit";
import {
  addDoc,
  collection,
  deleteDoc,
  deleteField,
  doc,
  updateDoc,
  writeBatch,
} from "firebase/firestore";
import omit from "lodash.omit";
import { db } from "../../config/firebase";
import {
  calculateAnchorCoordinates,
  isShape,
  filterNotNull,
} from "../../utils/diagramUtils";

// Initial state of the "diagram" slice: the open diagram's id and document,
// plus its shapes and connections stored as objects keyed by document id
// (the reducers below index them as `shapes[id]` / `connections[id]`).
// Each of the three resources has its own loading flag and error value;
// errors start out as `false`.
const initialState = {
  diagramId: null,
  diagram: null,
  diagramLoading: false,
  diagramError: false,
  shapes: {},
  shapesLoading: false,
  shapesError: false,
  connections: {},
  connectionsLoading: false,
  connectionsError: false,
  // NOTE(review): not read or written by any reducer or selector in this
  // file — possibly unused; confirm before relying on it.
  connectionIdByAnchorShapeId: {},
};

// **************
// DIAGRAMS
// **************
/**
 * Writes the given diagram fields to `diagrams/{diagramId}` in Firestore.
 * The `id` property is stripped from the payload before the update.
 *
 * @param {{ diagramId: string, diagram: object }} arg
 */
export const updateDiagram = createAsyncThunk(
  "diagram/updateDiagram",
  async ({ diagramId, diagram }) => {
    const diagramRef = doc(db, "diagrams", diagramId);
    const fields = omit(diagram, ["id"]);
    await updateDoc(diagramRef, fields);
  }
);

/**
 * Persists a mixed list of shapes and connections in a single Firestore batch.
 *
 * Each element is written to its `shapes` or `connections` subcollection
 * (chosen by `isShape`) with `id` stripped from the payload. For every shape
 * in the list, the connections anchored to it are recomputed and written in
 * the same batch.
 *
 * @param {{ diagramId: string, elements: object[] }} arg
 */
export const updateElements = createAsyncThunk(
  "diagram/updateElements",
  async ({ diagramId, elements }, { getState }) => {
    const batch = writeBatch(db);
    const elementRef = (subcollection, elementId) =>
      doc(db, `diagrams/${diagramId}/${subcollection}/${elementId}`);

    elements.forEach((element) => {
      const subcollection = isShape(element) ? "shapes" : "connections";
      batch.update(elementRef(subcollection, element.id), omit(element, ["id"]));
    });

    // Connection endpoints are derived from the shapes currently held in
    // Redux state, not from the `elements` argument.
    const anchoredConnections = elements
      .filter(isShape)
      .flatMap(({ id }) =>
        getUpdatedConnectionsAfterShapeChange(getState().diagram, {
          payload: { id },
        })
      );

    for (const connection of anchoredConnections) {
      batch.update(
        elementRef("connections", connection.id),
        omit(connection, ["id"])
      );
    }

    await batch.commit();
  }
);

// **************
// SHAPES
// **************
/**
 * Adds a new shape document under `diagrams/{diagramId}/shapes`.
 * Any `id` on the given shape is dropped, so `addDoc` assigns the document id.
 *
 * @param {{ diagramId: string, shape: object }} arg
 */
export const createShape = createAsyncThunk(
  "diagram/createShape",
  async ({ diagramId, shape }) => {
    const shapesRef = collection(db, `diagrams/${diagramId}/shapes`);
    const data = omit(shape, ["id"]);
    await addDoc(shapesRef, data);
  }
);

/**
 * Deletes a shape and detaches every connection anchored to it.
 *
 * Connections whose start and/or end anchor references the deleted shape are
 * rewritten through `updateElements` with that anchor removed. Their `start`
 * and `end` coordinates are left untouched. That update is dispatched and
 * awaited first; its own failure is reported via `updateElements.rejected`
 * rather than rethrown here. The shape document is then deleted and awaited,
 * so a failed delete rejects this thunk.
 *
 * @param {{ diagramId: string, shapeId: string }} arg
 */
export const deleteShape = createAsyncThunk(
  "diagram/deleteShape",
  async ({ diagramId, shapeId }, { dispatch, getState }) => {
    const connections = selectConnections(getState());
    const connectionsToUpdate = Object.values(connections)
      .filter(
        (connection) =>
          connection?.startAnchor?.shapeId === shapeId ||
          connection?.endAnchor?.shapeId === shapeId
      )
      .map((connection) => ({
        ...connection,
        // Dropping the key locally is not enough: Firestore updates only touch
        // the fields present in the payload, so the `deleteField()` sentinel
        // is required to actually remove the stored anchor.
        ...(connection?.startAnchor?.shapeId === shapeId && {
          startAnchor: deleteField(),
        }),
        ...(connection?.endAnchor?.shapeId === shapeId && {
          endAnchor: deleteField(),
        }),
      }));
    if (connectionsToUpdate.length > 0) {
      await dispatch(
        updateElements({ diagramId, elements: connectionsToUpdate })
      );
    }
    await deleteDoc(doc(db, `diagrams/${diagramId}/shapes/${shapeId}`));
  }
);

// **************
// CONNECTIONS
// **************
/**
 * Adds a new connection document under `diagrams/{diagramId}/connections`.
 * Any `id` on the given connection is dropped, so `addDoc` assigns the
 * document id.
 *
 * @param {{ diagramId: string, connection: object }} arg
 */
export const createConnection = createAsyncThunk(
  "diagram/createConnection",
  async ({ diagramId, connection }) => {
    const connectionsRef = collection(db, `diagrams/${diagramId}/connections`);
    const data = omit(connection, ["id"]);
    await addDoc(connectionsRef, data);
  }
);

/**
 * Deletes one connection document from `diagrams/{diagramId}/connections`.
 *
 * The delete is awaited. The thunk therefore fulfills only after Firestore
 * accepts the write, and a failure rejects the thunk instead of becoming an
 * unhandled promise rejection.
 *
 * @param {{ diagramId: string, connectionId: string }} arg
 */
export const deleteConnection = createAsyncThunk(
  "diagram/deleteConnection",
  async ({ diagramId, connectionId }) => {
    await deleteDoc(
      doc(db, `diagrams/${diagramId}/connections/${connectionId}`)
    );
  }
);

/**
 * Recomputes the anchored endpoints of every connection attached to a shape.
 *
 * Takes a reducer-style `(state, action)` pair; only `action.payload.id` is
 * read. If an anchor's shape is not present in `state.shapes`, that endpoint
 * keeps the connection's current `start` / `end` value.
 *
 * @param {object} state - Diagram slice state with `shapes` and `connections`.
 * @param {{ payload: { id: string } }} action - `payload.id` is the shape id.
 * @returns {Array<{ id: string, start?: object, end?: object }>} One partial
 *   connection per affected connection. Each carries `start` only if it has a
 *   start anchor and `end` only if it has an end anchor.
 */
const getUpdatedConnectionsAfterShapeChange = (state, { payload: { id } }) =>
  Object.values(state.connections)
    .filter(
      (connection) =>
        connection?.startAnchor?.shapeId === id ||
        connection?.endAnchor?.shapeId === id
    )
    .map((connection) => {
      const { startAnchor, endAnchor } = connection;
      const startShape = startAnchor && state.shapes[startAnchor.shapeId];
      const endShape = endAnchor && state.shapes[endAnchor.shapeId];
      const startCoords = startShape
        ? calculateAnchorCoordinates(startShape, startAnchor.position)
        : connection.start;
      // Fall back to the connection's own end point; falling back to
      // `connection.start` would collapse the end onto the start.
      const endCoords = endShape
        ? calculateAnchorCoordinates(endShape, endAnchor.position)
        : connection.end;
      // NOTE(review): this compares against `connection.x` / `connection.y`,
      // while the endpoints live in `connection.start` / `connection.end`.
      // If connections carry no x/y fields the test is always true, so every
      // anchored connection is returned. `updateElements` appears to rely on
      // that to persist coordinates already applied to local state. Confirm
      // before narrowing this to real change detection.
      if (
        (startAnchor &&
          (startCoords.x !== connection.x || startCoords.y !== connection.y)) ||
        (endAnchor &&
          (endCoords.x !== connection.x || endCoords.y !== connection.y))
      ) {
        return {
          id: connection.id,
          ...(startAnchor && {
            start: startCoords,
          }),
          ...(endAnchor && {
            end: endCoords,
          }),
        };
      }
      return null;
    })
    .filter(filterNotNull);

/**
 * Case reducer that mutates the draft state.
 *
 * Shallow-merges `payload.shape` into the stored shape `payload.id`, always
 * keeping an `id` field. It then merges the recomputed endpoints of every
 * connection anchored to that shape into `state.connections`.
 */
const setShapeReducer = (state, action) => {
  const { id: shapeId, shape } = action.payload;
  state.shapes[shapeId] = { id: shapeId, ...state.shapes[shapeId], ...shape };

  const connectionUpdates = getUpdatedConnectionsAfterShapeChange(state, {
    payload: { id: shapeId, shape },
  });
  for (const { id: connectionId, ...changes } of connectionUpdates) {
    const existing = state.connections[connectionId];
    state.connections[connectionId] = {
      id: connectionId,
      ...existing,
      ...changes,
    };
  }
};

/**
 * Redux slice for the open diagram. It holds the diagram document plus its
 * shapes and connections (objects keyed by document id), each with a loading
 * flag and an error value.
 *
 * NOTE: the exported name `counterSlice` does not match the slice name
 * ("diagram"). It is left unchanged because it is part of the module's
 * exports.
 */
export const counterSlice = createSlice({
  name: "diagram",
  initialState,
  reducers: {
    // Clears all three error values. They are reset to `false`, the value
    // `initialState` uses, so "no error" has a single representation.
    resetErrors: (state) => {
      state.diagramError = false;
      state.shapesError = false;
      state.connectionsError = false;
    },
    // **************
    // DIAGRAMS
    // **************
    // payload: the diagram id.
    setDiagramId: (state, { payload: id }) => {
      state.diagramId = id;
    },
    // payload: { diagram } — shallow-merged into the stored diagram; also
    // ends the diagram loading state.
    setDiagram: (state, { payload: { diagram } }) => {
      state.diagram = { ...state.diagram, ...diagram };
      state.diagramLoading = false;
    },
    setDiagramLoading: (state) => {
      state.diagramLoading = true;
    },
    // payload: the error value to store (diagramLoading is not changed).
    setDiagramError: (state, { payload: error }) => {
      state.diagramError = error;
    },
    // **************
    // SHAPES
    // **************
    // payload: optional array of shape ids; removes those shapes, or every
    // shape when no ids are given.
    clearShapes: (state, { payload: ids }) => {
      if (ids) ids.forEach((id) => delete state.shapes[id]);
      else state.shapes = {};
    },
    clearShape: (state, { payload: shapeId }) => {
      delete state.shapes[shapeId];
    },
    // payload: { id, shape } — merges the shape and refreshes the endpoints
    // of connections anchored to it.
    setShape: setShapeReducer,
    // payload: array of { id, shape }, each applied like `setShape`.
    setShapes: (state, { payload: shapes }) => {
      shapes.forEach(({ id, shape }) => {
        setShapeReducer(state, { payload: { id, shape } });
      });
    },
    // TODO: Error handling for all network errors
    setShapesLoading: (state) => {
      state.shapesLoading = true;
    },
    setShapesDoneLoading: (state) => {
      state.shapesLoading = false;
    },
    setShapesError: (state, { payload: error }) => {
      state.shapesError = error;
    },
    // **************
    // CONNECTIONS
    // **************
    // payload: optional array of connection ids; removes those connections,
    // or every connection when no ids are given.
    clearConnections: (state, { payload: ids }) => {
      if (ids)
        ids.forEach((id) => {
          delete state.connections[id];
        });
      else state.connections = {};
    },
    clearConnection: (state, { payload: connectionId }) => {
      delete state.connections[connectionId];
    },
    // TODO: Error handling for all network errors
    setConnectionsLoading: (state) => {
      state.connectionsLoading = true;
    },
    setConnectionsDoneLoading: (state) => {
      state.connectionsLoading = false;
    },
    setConnectionsError: (state, { payload: error }) => {
      state.connectionsError = error;
    },
    // payload: { id, connection } — shallow-merged into the stored connection.
    setConnection: (state, { payload: { id, connection } }) => {
      state.connections[id] = { id, ...state.connections[id], ...connection };
    },
    // payload: array of { id, connection }, each applied like `setConnection`.
    setConnections: (state, { payload: connections }) => {
      connections.forEach(({ id, connection }) => {
        state.connections[id] = { id, ...state.connections[id], ...connection };
      });
    },
  },
});

// Action creators generated by the slice, re-exported for dispatching.
export const {
  resetErrors,
  // Diagram
  setDiagramId,
  setDiagram,
  setDiagramLoading,
  setDiagramError,
  // Shapes
  clearShapes,
  clearShape,
  setShape,
  setShapes,
  setShapesLoading,
  setShapesDoneLoading,
  setShapesError,
  // Connections
  clearConnections,
  clearConnection,
  setConnection,
  setConnections,
  setConnectionsLoading,
  setConnectionsDoneLoading,
  setConnectionsError,
} = counterSlice.actions;

// **************
// DIAGRAMS
// **************
// Selectors over the diagram document fields of `state.diagram`.
export const selectDiagramId = ({ diagram }) => diagram.diagramId;
export const selectDiagram = ({ diagram }) => diagram.diagram;
export const selectDiagramLoading = ({ diagram }) => diagram.diagramLoading;
export const selectDiagramError = ({ diagram }) => diagram.diagramError;

// Builds a selector that returns the shape with `elementId` if one exists
// (truthy), otherwise the connection with that id.
export const selectElementCreator = (elementId) => (state) => {
  const shape = selectShape(state, elementId);
  return shape || selectConnection(state, elementId);
};

// **************
// SHAPES
// **************
// Selectors over the shapes map (keyed by shape id) and its status fields.
export const selectShapeIds = (state) => Object.keys(selectShapes(state));
export const selectShapes = ({ diagram }) => diagram.shapes;
export const selectShape = (state, shapeId) => selectShapes(state)[shapeId];
export const selectShapesLoading = ({ diagram }) => diagram.shapesLoading;
export const selectShapesError = ({ diagram }) => diagram.shapesError;

// **************
// CONNECTIONS
// **************
// Selectors over the connections map (keyed by connection id) and its
// status fields.
export const selectConnectionIds = (state) =>
  Object.keys(selectConnections(state));
export const selectConnections = ({ diagram }) => diagram.connections;
export const selectConnection = (state, connectionId) =>
  selectConnections(state)[connectionId];
export const selectConnectionsLoading = ({ diagram }) =>
  diagram.connectionsLoading;
export const selectConnectionsError = ({ diagram }) => diagram.connectionsError;

// The slice reducer, exported as the module default for store registration.
const diagramReducer = counterSlice.reducer;
export default diagramReducer;
