import React, { FunctionComponent, useMemo, useState } from 'react'
import { localPoint } from '@visx/event'
import { Group } from '@visx/group'
import {
  hierarchy,
  Treemap as VxTreemap,
  treemapSquarify
} from '@visx/hierarchy'
import { Legend, LegendItem, LegendLabel } from '@visx/legend'
import ParentSize from '@visx/responsive/lib/components/ParentSize'
import { scaleLinear } from '@visx/scale'
import { useTooltip, useTooltipInPortal } from '@visx/tooltip'
import { interpolateSpectral } from 'd3-scale-chromatic'
import { filter, maxBy, minBy, sumBy, uniqBy } from 'lodash'
import {
  Box,
  FormControlLabel,
  makeStyles,
  Paper,
  Radio,
  RadioGroup,
  Typography
} from '@material-ui/core'
import Toolbar from '../Toolbar'

// Style rules shared by the treemap and its responsive wrapper.
const useStyles = makeStyles(({ typography }) => ({
  // Flex container that hosts the chart; anything overflowing it is clipped.
  appGraph: {
    display: 'flex',
    flex: 1,
    overflow: 'hidden'
  },
  // Small, centred strip that holds the colour legend.
  legend: {
    lineHeight: '0.9em',
    fontSize: '10px',
    fontFamily: typography.fontFamily,
    padding: '10px 10px',
    margin: '5px 5px',
    display: 'flex',
    justifyContent: 'center'
  }
}))

/**
 * How the size of each treemap tile is computed.
 *
 * The string values double as the radio-button values of the mode selector.
 */
enum Mode {
  /** Tile size is the number of records that share the same name. */
  Count = 'Count',
  /** Tile size is the sum of the value property over records sharing a name. */
  Value = 'Value'
}

/** Payload handed to the tooltip when the pointer moves over a tile. */
type TooltipData = {
  /** Name of the hovered tile. */
  name?: string,
  /** Aggregated value (count or sum, depending on the mode) of the tile. */
  value?: number
}

/** Margin used by the treemap when the `margin` prop is not supplied. */
const defaultMargin = { top: 0, right: 0, bottom: 0, left: 0 }

/** Props accepted by the fixed-size treemap component. */
interface Props {
  /** Records to visualise; they are grouped by `nameProperty`. */
  data: Record<string, unknown>[],
  /** Width of the chart SVG, in pixels. */
  width: number,
  /** Height of the chart SVG, in pixels. */
  height: number,
  /** Space subtracted from the SVG size before laying out the tiles. */
  margin?: { top: number; right: number; bottom: number; left: number },
  /** Key of each record whose value names the tile the record belongs to. */
  nameProperty: string,
  /** Optional formatter for values shown in the legend and tooltip in value mode. */
  renderValue?: (value: number) => string,
  /** Title forwarded to the toolbar. */
  title?: string,
  /** Key of each record that is summed per tile in value mode; also used as a label. */
  valueProperty: string
}

/** Width, in pixels, of the colour ramp drawn inside the legend. */
const legendRampWidth = 300
/** Height, in pixels, of the colour ramp drawn inside the legend. */
const legendRampHeight = 15
/** Number of rectangles used to approximate the continuous colour gradient. */
const legendRampSteps = 512

/**
 * Builds the rectangles that approximate the spectral colour gradient.
 *
 * @param width total width of the ramp, in pixels
 * @param steps number of rectangles to generate
 * @returns one `{ width, x, fill }` descriptor per step, ordered left to right
 */
const buildColourRamp = (width: number, steps: number) => {
  const stepWidth = width / steps
  return Array.from({ length: steps }, (_unused, i) => ({
    // One extra pixel so neighbouring rectangles overlap instead of leaving gaps.
    width: (i + 1) * stepWidth - i * stepWidth + 1,
    x: i * stepWidth,
    // A single-step ramp would otherwise divide by zero.
    fill: interpolateSpectral(steps > 1 ? i / (steps - 1) : 0)
  }))
}

/**
 * Treemap of `data` grouped by `nameProperty`.
 *
 * Renders a toolbar, a radio group that switches between sizing tiles by
 * record count or by the sum of `valueProperty`, a colour legend, and the SVG
 * treemap itself. Hovering a tile shows a tooltip with its name and value.
 */
const Treemap: FunctionComponent<Props> = (props: Props) => {
  const {
    data,
    width,
    height,
    margin = defaultMargin,
    nameProperty,
    renderValue,
    title,
    valueProperty
  } = props
  const classes = useStyles()
  const {
    hideTooltip,
    showTooltip,
    tooltipData,
    tooltipLeft,
    tooltipOpen,
    tooltipTop
  } = useTooltip<TooltipData>()
  const [mode, setMode] = useState<Mode>(Mode.Count)
  const { containerRef, TooltipInPortal } = useTooltipInPortal({ scroll: true })

  // Area available to the tiles once the margins are removed.
  const xMax = width - margin.left - margin.right
  const yMax = height - margin.top - margin.bottom

  // One child node per distinct name, largest first, under a synthetic root.
  const root = useMemo(() => {
    const unique = uniqBy(data, nameProperty)
    const formatted = unique.map((d) => {
      // All records that belong to the same tile as `d`.
      const matching = filter(data, [nameProperty, d[nameProperty]])
      return {
        id: String(d[nameProperty]),
        name: d[nameProperty],
        value: mode === Mode.Count
          ? matching.length
          : sumBy(matching, valueProperty)
      }
    })
    const sorted = formatted.sort((a, b) => b.value - a.value)
    // NOTE(review): the nodes carry an explicit `value` and no `.sum()` is
    // called; presumably the hierarchy implementation picks `value` up from
    // the data - confirm against the installed d3-hierarchy version.
    const nodes = hierarchy({
      children: sorted,
      name: 'Root',
      value: mode === Mode.Count
        ? data.length
        : sumBy(data, valueProperty)
    })
    return nodes
  }, [data, mode, nameProperty, valueProperty])

  // Maps the smallest..largest tile value onto 0..1 for the colour interpolator.
  const colourScale = useMemo(() =>
    scaleLinear({
      domain: [
        minBy(root.children, 'value')?.value || 0,
        maxBy(root.children, 'value')?.value || 0
      ],
      range: [0, 1]
    }), [root.children])

  // The ramp never changes, so build it once rather than on every render
  // (tooltip updates re-render this component on each mouse move).
  const legendRects = useMemo(
    () => buildColourRamp(legendRampWidth, legendRampSteps),
    []
  )

  const handleModeChange = (_event: React.ChangeEvent<HTMLInputElement>, value: string) => {
    setMode(value as Mode)
  }

  return (
    <>
      <Toolbar title={title}/>
      <Box marginLeft={2}>
        <RadioGroup value={mode} onChange={handleModeChange} row>
          <FormControlLabel value={Mode.Count} control={<Radio/>} label="Count"/>
          <FormControlLabel value={Mode.Value} control={<Radio/>} label={valueProperty}/>
        </RadioGroup>
      </Box>
      <div className={classes.legend}>
        <Legend scale={colourScale}>
          { labels => (
            <Box
              display="flex"
              flexDirection="row"
            >
              <LegendItem>
                <LegendLabel flex="stretch" align="right" margin="0 4px">
                  { (mode === Mode.Count || !renderValue) &&
                    labels[0].text
                  }
                  { mode === Mode.Value && renderValue &&
                    renderValue(Number(labels[0].text))
                  }
                </LegendLabel>
                <svg height={legendRampHeight} width={legendRampWidth}>
                  {legendRects.map((params, index) => (
                    <rect
                      key={`legend_scale_rect_${index}`}
                      height={legendRampHeight}
                      width={params.width}
                      x={params.x}
                      fill={params.fill}
                      strokeWidth={0}
                    />
                  ))}
                </svg>
                <LegendLabel flex="stretch" align="left" margin="0 4px">
                  { (mode === Mode.Count || !renderValue) &&
                    labels[1].text
                  }
                  { mode === Mode.Value && renderValue &&
                    renderValue(Number(labels[1].text))
                  }
                </LegendLabel>
              </LegendItem>
            </Box>
          )}
        </Legend>
      </div>
      <svg
        ref={containerRef}
        width={width}
        height={height}
      >
        <VxTreemap
          top={margin.top}
          root={root}
          size={[xMax, yMax]}
          tile={treemapSquarify}
          round
        >
          { treemap => {
            return (
              <Group>
                { treemap
                  .descendants()
                  .map((node, i) => {
                    const nodeWidth = node.x1 - node.x0
                    const nodeHeight = node.y1 - node.y0
                    return (
                      <Group
                        key={`node-${i}`}
                        top={node.y0 + margin.top}
                        left={node.x0 + margin.left}
                      >
                        { node.depth === 0 &&
                          <rect
                            width={nodeWidth}
                            height={nodeHeight}
                            stroke={'#000000'}
                            fill="transparent"
                          />
                        }
                        { node.depth > 0 &&
                          <Group
                            onMouseMove={event => {
                              const eventSvgCoords = localPoint(event)
                              showTooltip({
                                tooltipData: {
                                  name: node.data.name,
                                  value: node.value
                                },
                                tooltipTop: eventSvgCoords?.y,
                                tooltipLeft: eventSvgCoords?.x
                              })
                            }}
                            onMouseLeave={() => hideTooltip()}
                          >
                            <rect
                              width={nodeWidth}
                              height={nodeHeight}
                              stroke="#000000"
                              fill={interpolateSpectral(
                                colourScale(node.value || 0) || 0
                              )}
                            />
                            <text
                              x={5}
                              y={16}
                              fill="#000000"
                            >
                              {node.data.name}
                            </text>
                          </Group>
                        }
                      </Group>
                    )
                  })
                }
              </Group>
            )
          }}
        </VxTreemap>
      </svg>
      { tooltipOpen && tooltipData && (
        <TooltipInPortal left={tooltipLeft} top={tooltipTop}>
          <Typography component="div">
            <Box fontWeight="fontWeightBold">
              {tooltipData.name}
            </Box>
            <Box>
              { mode === Mode.Count &&
                `Count: ${tooltipData.value}`
              }
              {
                mode === Mode.Value && !renderValue &&
                `${valueProperty}: ${tooltipData.value}`
              }
              {
                // Compare against undefined explicitly: a bare `value &&`
                // would make React print a stray "0" for zero-valued tiles.
                mode === Mode.Value && renderValue && tooltipData.value !== undefined &&
                `${valueProperty}: ${renderValue(tooltipData.value)}`
              }
            </Box>
          </Typography>
        </TooltipInPortal>
      )}
    </>
  )
}

/**
 * Props of the responsive wrapper: the treemap props with `width` and
 * `height` made optional.
 */
interface ResponsiveProps extends Omit<Props, 'height' | 'width'> {
  /** Explicit width; when omitted the measured parent width is used. */
  width?: number,
  /** Explicit height; when omitted the measured parent height is used. */
  height?: number
}

/**
 * Treemap wrapped in a Paper that sizes itself to its parent.
 *
 * An explicit `width` or `height` prop takes precedence; any dimension left
 * unset falls back to the size measured by `ParentSize`.
 */
const ResponsiveTreemap: FunctionComponent<ResponsiveProps> =
(props: ResponsiveProps) => {
  const classes = useStyles()
  const { height: fixedHeight, width: fixedWidth } = props

  return (
    <Paper className={classes.appGraph}>
      <ParentSize>
        {(measured) => {
          const resolvedWidth = fixedWidth ?? measured.width
          const resolvedHeight = fixedHeight ?? measured.height
          return (
            <Treemap
              {...props}
              width={resolvedWidth}
              height={resolvedHeight}
            />
          )
        }}
      </ParentSize>
    </Paper>
  )
}

export default ResponsiveTreemap
