import { AsyncButton, Button } from '@components/button';
import { Form, FormSelect } from '@components/form';
import { useGraphqlPagination } from '@components/infiniteScroll';
import {
  Dataset,
  GetDatasetsQuery,
  GetWorkspaceMlModelsQuery,
  MlModel,
  useCreateMlInferenceMutation,
  useGetDatasetsQuery,
  useGetWorkspaceMlModelsQuery,
} from '@generated/UseGraphqlHooks';
import { MenuItem, Stack } from '@mui/material';
import { useEffect } from 'react';
import * as yup from 'yup';

/**
 * Validation rules for the create-inference form.
 * Both `modelId` and `datasetId` are trimmed strings and must be provided.
 */
const validationSchema = yup.object().shape({
  modelId: yup.string().trim().required('Model is required'),
  datasetId: yup.string().trim().required('Dataset is required'),
});

/**
 * Form for starting an ML inference run: the user picks a model and a dataset
 * from the workspace, and submitting calls the `createMlInference` mutation.
 *
 * @param workspaceId - Workspace whose models and datasets are listed and in which the inference is created.
 * @param datasetId - Optional dataset to preselect in the form.
 * @param modelId - Optional model to preselect in the form.
 * @param onClose - Called after a successful create, and when Cancel is clicked.
 * @param refetch - Called after a successful create (presumably so the caller can refresh its list).
 */
export const CreateInferenceModal = ({
  workspaceId,
  datasetId: defaultDatasetId,
  modelId: defaultModelId,
  onClose,
  refetch,
}: {
  workspaceId: string;
  datasetId?: string;
  modelId?: string;
  onClose: () => void;
  refetch: () => void;
}) => {
  // Both lists are requested 100 items at a time; `fetchMore*` presumably loads the next page.
  const { data: models, fetchMore: fetchMoreModels } = useGraphqlPagination<
    GetWorkspaceMlModelsQuery,
    MlModel
  >(
    useGetWorkspaceMlModelsQuery({ variables: { workspaceId, limit: 100 } }),
    'getMLModels',
    'modelId',
  );

  // Keep requesting more models until the preselected one is among the loaded items.
  // `fetchMoreModels` is intentionally not a dependency: its identity across renders is
  // not guaranteed here, and re-running on every new reference could issue duplicate fetches.
  useEffect(() => {
    if (defaultModelId && !models.some((model) => model?.modelId === defaultModelId)) {
      fetchMoreModels();
    }
  }, [models, defaultModelId]);

  const { data: datasets, fetchMore: fetchMoreDatasets } = useGraphqlPagination<
    GetDatasetsQuery,
    Dataset
  >(useGetDatasetsQuery({ variables: { workspaceId, limit: 100 } }), 'getDatasets', 'datasetId');

  // Same as above, for the preselected dataset (same reason for omitting `fetchMoreDatasets`).
  useEffect(() => {
    if (defaultDatasetId && !datasets.some((dataset) => dataset?.datasetId === defaultDatasetId)) {
      fetchMoreDatasets();
    }
  }, [datasets, defaultDatasetId]);

  const [createInference] = useCreateMlInferenceMutation();

  const handleFormSubmit = async ({
    datasetId,
    modelId,
  }: {
    datasetId: string;
    modelId: string;
  }) => {
    // If the mutation rejects, the error propagates to the form and the modal stays open.
    await createInference({
      variables: {
        workspaceId,
        datasetId,
        modelId,
      },
    });
    // `refetch` is typed as required; the optional call only guards untyped callers.
    refetch?.();
    onClose();
  };

  // Only models whose status is 'success' are offered (this also drops null entries).
  // NOTE(review): if `defaultModelId` refers to a loaded model whose status is not 'success',
  // paging stops but that model has no option here, leaving the select with an unmatched value.
  const selectableModels = models.filter((model) => model?.status === 'success');
  // Drop null entries so every MenuItem gets a defined key and value.
  const selectableDatasets = datasets.filter((dataset) => Boolean(dataset));

  return (
    <Form
      initialValues={{ datasetId: defaultDatasetId || '', modelId: defaultModelId || '' }}
      validateOnBlur={false}
      validationSchema={validationSchema}
      onSubmit={handleFormSubmit}
    >
      {({ handleSubmit, isSubmitting }) => (
        <Stack>
          <FormSelect name="modelId" label="Model" fullWidth fetchMore={fetchMoreModels}>
            {selectableModels.map((model) => (
              <MenuItem key={model?.modelId} value={model?.modelId}>
                {model?.name}
              </MenuItem>
            ))}
          </FormSelect>
          <FormSelect name="datasetId" label="Dataset" fullWidth fetchMore={fetchMoreDatasets}>
            {selectableDatasets.map((dataset) => (
              <MenuItem key={dataset?.datasetId} value={dataset?.datasetId}>
                {dataset?.name}
              </MenuItem>
            ))}
          </FormSelect>
          <Stack gap={4}>
            <AsyncButton
              fullWidth
              loading={isSubmitting}
              disabled={isSubmitting}
              onClick={handleSubmit}
            >
              Create
            </AsyncButton>
            <Button fullWidth variant="secondary" onClick={onClose}>
              Cancel
            </Button>
          </Stack>
        </Stack>
      )}
    </Form>
  );
};
