import copy
from datetime import datetime
import itertools
import random
import time
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from unittest.mock import Mock
from unittest.mock import patch

import pytest

import optuna
from optuna import Study
from optuna._callbacks import RetryFailedTrialCallback
from optuna.distributions import CategoricalDistribution
from optuna.distributions import LogUniformDistribution
from optuna.distributions import UniformDistribution
from optuna.storages import _CachedStorage
from optuna.storages import BaseStorage
from optuna.storages import InMemoryStorage
from optuna.storages import RDBStorage
from optuna.storages import RedisStorage
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.study._study_direction import StudyDirection
from optuna.study._study_summary import StudySummary
from optuna.testing.storage import STORAGE_MODES
from optuna.testing.storage import STORAGE_MODES_HEARTBEAT
from optuna.testing.storage import StorageSupplier
from optuna.testing.threading import _TestableThread
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


# Every TrialState member, in enum definition order.
ALL_STATES = [*TrialState]

# Sample attributes covering a plain string, a None value and a nested JSON-serializable value.
EXAMPLE_ATTRS = dict(
    dataset="MNIST",
    none=None,
    json_serializable={"baseline_score": 0.001, "tags": ["image", "classification"]},
)


def test_get_storage() -> None:
    """`get_storage` should return the storage class matching its argument."""
    resolve = optuna.storages.get_storage

    # No URL at all selects the in-memory storage.
    assert isinstance(resolve(None), InMemoryStorage)
    # An RDB URL is wrapped by the caching layer.
    assert isinstance(resolve("sqlite:///:memory:"), _CachedStorage)
    # A redis URL selects the Redis-backed storage.
    redis_url = "redis://test_user:passwd@localhost:6379/0"
    assert isinstance(resolve(redis_url), RedisStorage)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_create_new_study(storage_mode: str) -> None:
    """Each created study gets a distinct id and a default-prefixed name."""
    with StorageSupplier(storage_mode) as storage:
        first_id = storage.create_new_study()

        after_first = storage.get_all_study_summaries()
        assert len(after_first) == 1
        assert after_first[0]._study_id == first_id
        assert after_first[0].study_name.startswith(DEFAULT_STUDY_NAME_PREFIX)

        second_id = storage.create_new_study()
        # Ids handed out by the storage must never collide.
        assert first_id != second_id

        after_second = storage.get_all_study_summaries()
        assert len(after_second) == 2
        seen_ids = {summary._study_id for summary in after_second}
        assert seen_ids == {first_id, second_id}
        names = [summary.study_name for summary in after_second]
        assert all(name.startswith(DEFAULT_STUDY_NAME_PREFIX) for name in names)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_create_new_study_unique_id(storage_mode: str) -> None:
    """The id of a deleted study should not be handed out again."""
    with StorageSupplier(storage_mode) as storage:
        kept_id = storage.create_new_study()
        deleted_id = storage.create_new_study()
        storage.delete_study(deleted_id)
        recreated_id = storage.create_new_study()

        # RDB-backed storages (directly or behind the cache) are exempted from the uniqueness
        # assertion for now.
        # TODO(ytsmiling) Fix RDBStorage so that it does not reuse study_id.
        rdb_backed = isinstance(storage, (RDBStorage, _CachedStorage))
        if not rdb_backed:
            assert len({kept_id, deleted_id, recreated_id}) == 3

        remaining = {summary._study_id for summary in storage.get_all_study_summaries()}
        assert remaining == {kept_id, recreated_id}


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_create_new_study_with_name(storage_mode: str) -> None:
    """A named study keeps its name, and the name cannot be registered twice."""
    with StorageSupplier(storage_mode) as storage:
        # Derive a unique name from this test's function name plus the storage mode.
        study_name = f"{test_create_new_study_with_name.__name__}/{storage_mode}"
        study_id = storage.create_new_study(study_name)

        assert storage.get_study_name_from_id(study_id) == study_name

        # Reusing an existing name is rejected.
        with pytest.raises(optuna.exceptions.DuplicatedStudyError):
            storage.create_new_study(study_name)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_delete_study(storage_mode: str) -> None:
    """Deleting removes a study with its trials; unknown or already-deleted ids raise KeyError."""
    with StorageSupplier(storage_mode) as storage:
        doomed_id = storage.create_new_study()
        storage.create_new_trial(doomed_id)
        assert len(storage.get_all_trials(doomed_id)) == 1

        # An id that was never created cannot be deleted.
        with pytest.raises(KeyError):
            storage.delete_study(doomed_id + 1)

        storage.delete_study(doomed_id)

        # A study created afterwards starts with no trials.
        fresh_id = storage.create_new_study()
        assert len(storage.get_all_trials(fresh_id)) == 0

        storage.delete_study(fresh_id)
        # Deleting the same study twice is an error.
        with pytest.raises(KeyError):
            storage.delete_study(fresh_id)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_delete_study_after_create_multiple_studies(storage_mode: str) -> None:
    """Deleting the middle study leaves its neighbours untouched."""
    with StorageSupplier(storage_mode) as storage:
        first, middle, last = (storage.create_new_study() for _ in range(3))

        storage.delete_study(middle)

        remaining = {summary._study_id for summary in storage.get_all_study_summaries()}
        assert first in remaining
        assert middle not in remaining
        assert last in remaining


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_study_id_from_name_and_get_study_name_from_id(storage_mode: str) -> None:
    """Name <-> id lookups round-trip, and unknown names/ids raise KeyError."""
    with StorageSupplier(storage_mode) as storage:
        # Derive a unique name from this test's function name plus the storage mode.
        study_name = "/".join(
            (test_get_study_id_from_name_and_get_study_name_from_id.__name__, storage_mode)
        )
        storage = optuna.storages.get_storage(storage)
        study_id = storage.create_new_study(study_name=study_name)

        # Lookups for the study that exists.
        assert storage.get_study_name_from_id(study_id) == study_name
        assert storage.get_study_id_from_name(study_name) == study_id

        # Lookups for a name and an id that do not exist.
        with pytest.raises(KeyError):
            storage.get_study_id_from_name("dummy-name")
        with pytest.raises(KeyError):
            storage.get_study_name_from_id(study_id + 1)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_study_id_from_trial_id(storage_mode: str) -> None:
    """A trial id maps back to the id of the study it was created in."""
    with StorageSupplier(storage_mode) as storage:
        storage = optuna.storages.get_storage(storage)

        study_id = storage.create_new_study()
        trial_id = storage.create_new_trial(study_id)

        assert storage.get_study_id_from_trial_id(trial_id) == study_id


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_set_and_get_study_directions(storage_mode: str) -> None:
    """Directions start as NOT_SET, can be set once, and cannot be changed afterwards."""
    with StorageSupplier(storage_mode) as storage:

        def assert_round_trip(sid: int, directions: Sequence[StudyDirection]) -> None:
            # Whatever sequence type is passed in, the storage hands back a list.
            storage.set_study_directions(sid, directions)
            fetched = storage.get_study_directions(sid)
            assert fetched == list(
                directions
            ), "Direction of a study should be a tuple of `StudyDirection` objects."

        minimize, maximize = StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE
        # (directions to set, conflicting directions), given both as tuples and as lists.
        cases = [
            ((minimize,), (maximize,)),
            ((maximize,), (minimize,)),
            ((minimize, maximize), (maximize, minimize)),
            ([minimize, maximize], [maximize, minimize]),
        ]

        for target, opposite in cases:
            study_id = storage.create_new_study()

            initial = storage.get_study_directions(study_id)
            assert len(initial) == 1
            assert initial[0] == StudyDirection.NOT_SET

            assert_round_trip(study_id, target)

            # Re-applying the same directions is allowed.
            storage.set_study_directions(study_id, target)

            # Switching to the opposite directions, or back to NOT_SET, is rejected.
            for rejected in (opposite, (StudyDirection.NOT_SET,)):
                with pytest.raises(ValueError):
                    storage.set_study_directions(study_id, rejected)

            # For an unknown study the id check wins, even over invalid directions.
            missing_id = study_id + 1
            for directions in (opposite, (StudyDirection.NOT_SET,)):
                with pytest.raises(KeyError):
                    storage.set_study_directions(missing_id, directions)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_set_and_get_study_user_attrs(storage_mode: str) -> None:
    """Study user attrs can be set, read back and overwritten; unknown studies raise KeyError."""

    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()

        def check_set_and_get(key: str, value: Any) -> None:

            storage.set_study_user_attr(study_id, key, value)
            assert storage.get_study_user_attrs(study_id)[key] == value

        # Test setting value.
        for key, value in EXAMPLE_ATTRS.items():
            check_set_and_get(key, value)
        assert storage.get_study_user_attrs(study_id) == EXAMPLE_ATTRS

        # Test overwriting value.
        check_set_and_get("dataset", "ImageNet")

        # Non-existent study id: both the setter and the getter must raise.
        # (The getter check was previously duplicated verbatim.)
        non_existent_study_id = study_id + 1
        with pytest.raises(KeyError):
            storage.set_study_user_attr(non_existent_study_id, "key", "value")
        with pytest.raises(KeyError):
            storage.get_study_user_attrs(non_existent_study_id)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_set_and_get_study_system_attrs(storage_mode: str) -> None:
    """Study system attrs can be set, read back and overwritten; unknown studies raise KeyError."""

    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()

        def check_set_and_get(key: str, value: Any) -> None:

            storage.set_study_system_attr(study_id, key, value)
            assert storage.get_study_system_attrs(study_id)[key] == value

        # Test setting value.
        for key, value in EXAMPLE_ATTRS.items():
            check_set_and_get(key, value)
        assert storage.get_study_system_attrs(study_id) == EXAMPLE_ATTRS

        # Test overwriting value.
        check_set_and_get("dataset", "ImageNet")

        # Non-existent study id: both the setter and the getter must raise, mirroring the
        # user-attrs test.
        non_existent_study_id = study_id + 1
        with pytest.raises(KeyError):
            storage.set_study_system_attr(non_existent_study_id, "key", "value")
        with pytest.raises(KeyError):
            storage.get_study_system_attrs(non_existent_study_id)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_study_user_and_system_attrs_confusion(storage_mode: str) -> None:
    """Writing one kind of study attr must leave the other kind empty."""
    with StorageSupplier(storage_mode) as storage:
        # (setter to use, getter that must see the values, getter that must stay empty)
        cases = [
            (
                storage.set_study_system_attr,
                storage.get_study_system_attrs,
                storage.get_study_user_attrs,
            ),
            (
                storage.set_study_user_attr,
                storage.get_study_user_attrs,
                storage.get_study_system_attrs,
            ),
        ]
        for setter, written_getter, untouched_getter in cases:
            sid = storage.create_new_study()
            for name, attr in EXAMPLE_ATTRS.items():
                setter(sid, name, attr)
            assert written_getter(sid) == EXAMPLE_ATTRS
            assert untouched_getter(sid) == {}


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_create_new_trial(storage_mode: str) -> None:
    """New trials start RUNNING and empty, get per-study numbers from 0, and have unique ids."""

    def _check_trials(
        trials: List[FrozenTrial],
        idx: int,
        trial_id: int,
        time_before_creation: datetime,
        time_after_creation: datetime,
    ) -> None:
        """Check the ``idx + 1`` trials of a study right after ``trial_id`` was created."""
        assert len(trials) == idx + 1
        assert len({t._trial_id for t in trials}) == idx + 1
        assert trial_id in {t._trial_id for t in trials}
        assert {t.number for t in trials} == set(range(idx + 1))
        assert all(t.state == TrialState.RUNNING for t in trials)
        assert all(t.params == {} for t in trials)
        assert all(t.intermediate_values == {} for t in trials)
        assert all(t.user_attrs == {} for t in trials)
        assert all(t.system_attrs == {} for t in trials)
        # Comparisons are non-strict: consecutive datetime.now() calls may return identical
        # values on platforms with a coarse clock, which made strict `<` checks flaky.
        assert all(
            t.datetime_start <= time_before_creation
            for t in trials
            if t._trial_id != trial_id and t.datetime_start is not None
        )
        assert all(
            time_before_creation <= t.datetime_start <= time_after_creation
            for t in trials
            if t._trial_id == trial_id and t.datetime_start is not None
        )
        assert all(t.datetime_complete is None for t in trials)
        assert all(t.value is None for t in trials)

    with StorageSupplier(storage_mode) as storage:

        study_id = storage.create_new_study()
        n_trial_in_study = 3
        for i in range(n_trial_in_study):
            time_before_creation = datetime.now()
            trial_id = storage.create_new_trial(study_id)
            time_after_creation = datetime.now()

            trials = storage.get_all_trials(study_id)
            _check_trials(trials, i, trial_id, time_before_creation, time_after_creation)

        # Create trial in non-existent study.
        with pytest.raises(KeyError):
            storage.create_new_trial(study_id + 1)

        study_id2 = storage.create_new_study()
        for i in range(n_trial_in_study):
            storage.create_new_trial(study_id2)

            trials = storage.get_all_trials(study_id2)
            # Check that the offset of trial.number is zero.
            assert {t.number for t in trials} == set(range(i + 1))

        trials = storage.get_all_trials(study_id) + storage.get_all_trials(study_id2)
        # Check trial_ids are unique across studies.
        assert len({t._trial_id for t in trials}) == 2 * n_trial_in_study


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_create_new_trial_with_template_trial(storage_mode: str) -> None:
    """Trials created from a template copy every field except number and trial id."""
    started_at = datetime.now()
    completed_at = datetime.now()
    template = FrozenTrial(
        state=TrialState.COMPLETE,
        value=10000,
        datetime_start=started_at,
        datetime_complete=completed_at,
        params={"x": 0.5},
        distributions={"x": UniformDistribution(0, 1)},
        user_attrs={"foo": "bar"},
        system_attrs={"baz": 123},
        intermediate_values={1: 10, 2: 100, 3: 1000},
        number=55,  # This entry is ignored.
        trial_id=-1,  # dummy value (unused).
    )
    copied_fields = (
        "state",
        "params",
        "distributions",
        "intermediate_values",
        "user_attrs",
        "system_attrs",
        "datetime_start",
        "datetime_complete",
        "value",
    )

    def assert_trials_match_template(trials: List[FrozenTrial], idx: int, new_id: int) -> None:
        ids = {t._trial_id for t in trials}
        assert len(trials) == idx + 1
        assert len(ids) == idx + 1
        assert new_id in ids
        assert {t.number for t in trials} == set(range(idx + 1))
        for field in copied_fields:
            expected = getattr(template, field)
            assert all(getattr(t, field) == expected for t in trials)

    with StorageSupplier(storage_mode) as storage:
        first_study = storage.create_new_study()

        per_study = 3
        for idx in range(per_study):
            new_id = storage.create_new_trial(first_study, template_trial=template)
            assert_trials_match_template(storage.get_all_trials(first_study), idx, new_id)

        # Create trial in non-existent study.
        with pytest.raises(KeyError):
            storage.create_new_trial(first_study + 1)

        second_study = storage.create_new_study()
        for idx in range(per_study):
            storage.create_new_trial(second_study, template_trial=template)
            numbers = {t.number for t in storage.get_all_trials(second_study)}
            assert numbers == set(range(idx + 1))

        combined = storage.get_all_trials(first_study) + storage.get_all_trials(second_study)
        # Trial ids are unique across both studies.
        assert len({t._trial_id for t in combined}) == 2 * per_study


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trial_number_from_id(storage_mode: str) -> None:
    """Trial numbers within a study are assigned sequentially from 0."""
    with StorageSupplier(storage_mode) as storage:
        storage = optuna.storages.get_storage(storage)
        study_id = storage.create_new_study()

        for expected_number in range(2):
            new_trial_id = storage.create_new_trial(study_id)
            assert storage.get_trial_number_from_id(new_trial_id) == expected_number


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_set_trial_state(storage_mode: str) -> None:
    """Check trial state transitions and the timestamps they record.

    Each freshly created (RUNNING) trial is moved to one target state; WAITING is skipped. The
    test checks that ``datetime_complete`` is set exactly when the target state is finished.
    It then checks that a trial in any finished state rejects every further state change with
    ``RuntimeError``.
    """

    with StorageSupplier(storage_mode) as storage:

        study_id = storage.create_new_study()
        # One fresh trial per possible target state.
        trial_ids = [storage.create_new_trial(study_id) for _ in ALL_STATES]

        for trial_id, state in zip(trial_ids, ALL_STATES):
            if state == TrialState.WAITING:
                # The RUNNING -> WAITING transition is not exercised by this test.
                continue
            assert storage.get_trial(trial_id).state == TrialState.RUNNING
            datetime_start_prev = storage.get_trial(trial_id).datetime_start
            if state.is_finished():
                # Give the trial a value before moving it to a finished state.
                storage.set_trial_values(trial_id, (0.0,))
            storage.set_trial_state(trial_id, state)
            assert storage.get_trial(trial_id).state == state
            # Repeated state changes to RUNNING should not trigger further datetime_start changes.
            if state == TrialState.RUNNING:
                assert storage.get_trial(trial_id).datetime_start == datetime_start_prev
            if state.is_finished():
                assert storage.get_trial(trial_id).datetime_complete is not None
            else:
                assert storage.get_trial(trial_id).datetime_complete is None

        # A finished trial is frozen: no state change is accepted, whatever the target state.
        for state in ALL_STATES:
            if not state.is_finished():
                continue
            trial_id = storage.create_new_trial(study_id)
            storage.set_trial_values(trial_id, (0.0,))
            storage.set_trial_state(trial_id, state)
            for state2 in ALL_STATES:
                # Cannot update states of finished trials.
                with pytest.raises(RuntimeError):
                    storage.set_trial_state(trial_id, state2)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trial_param_and_get_trial_params(storage_mode: str) -> None:
    """`get_trial_params` returns external values; `get_trial_param` returns internal ones."""
    with StorageSupplier(storage_mode) as storage:
        _, study_to_trials = _setup_studies(storage, n_study=2, n_trial=5, seed=1)

        for trials in study_to_trials.values():
            for trial_id, expected in trials.items():
                assert storage.get_trial_params(trial_id) == expected.params
                for name, external_value in expected.params.items():
                    internal_value = expected.distributions[name].to_internal_repr(
                        external_value
                    )
                    assert storage.get_trial_param(trial_id, name) == internal_value

        # The largest known trial id plus one is guaranteed not to exist.
        missing_trial_id = max(tid for trials in study_to_trials.values() for tid in trials) + 1
        with pytest.raises(KeyError):
            storage.get_trial_params(missing_trial_id)
        with pytest.raises(KeyError):
            storage.get_trial_param(missing_trial_id, "paramA")

        # An existing trial queried for an unknown parameter name.
        with pytest.raises(KeyError):
            storage.get_trial_param(missing_trial_id - 1, "dummy-key")


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_set_trial_param(storage_mode: str) -> None:
    """Check setting trial params, including overwrite and distribution-compatibility rules.

    ``get_trial_param`` is expected to return the internal representation, e.g. the choice
    index for a ``CategoricalDistribution``. ``get_trial``/``get_trial_params`` are expected to
    return the external value, e.g. the choice itself.
    """

    with StorageSupplier(storage_mode) as storage:

        # Setup test across multiple studies and trials.
        study_id = storage.create_new_study()
        trial_id_1 = storage.create_new_trial(study_id)
        trial_id_2 = storage.create_new_trial(study_id)
        trial_id_3 = storage.create_new_trial(storage.create_new_study())

        # Setup distributions.
        distribution_x = UniformDistribution(low=1.0, high=2.0)
        distribution_y_1 = CategoricalDistribution(choices=("Shibuya", "Ebisu", "Meguro"))
        distribution_y_2 = CategoricalDistribution(choices=("Shibuya", "Shinsen"))
        distribution_z = LogUniformDistribution(low=1.0, high=100.0)

        # Set new params.
        storage.set_trial_param(trial_id_1, "x", 0.5, distribution_x)
        # Internal value 2 is the index of "Meguro" in distribution_y_1's choices.
        storage.set_trial_param(trial_id_1, "y", 2, distribution_y_1)
        assert storage.get_trial_param(trial_id_1, "x") == 0.5
        assert storage.get_trial_param(trial_id_1, "y") == 2
        # Check set_param breaks neither get_trial nor get_trial_params.
        assert storage.get_trial(trial_id_1).params == {"x": 0.5, "y": "Meguro"}
        assert storage.get_trial_params(trial_id_1) == {"x": 0.5, "y": "Meguro"}
        # Duplicated registration should overwrite.
        storage.set_trial_param(trial_id_1, "x", 0.6, distribution_x)
        assert storage.get_trial_param(trial_id_1, "x") == 0.6
        assert storage.get_trial(trial_id_1).params == {"x": 0.6, "y": "Meguro"}
        assert storage.get_trial_params(trial_id_1) == {"x": 0.6, "y": "Meguro"}

        # Set params to another trial.
        storage.set_trial_param(trial_id_2, "x", 0.3, distribution_x)
        storage.set_trial_param(trial_id_2, "z", 0.1, distribution_z)
        assert storage.get_trial_param(trial_id_2, "x") == 0.3
        assert storage.get_trial_param(trial_id_2, "z") == 0.1
        assert storage.get_trial(trial_id_2).params == {"x": 0.3, "z": 0.1}
        assert storage.get_trial_params(trial_id_2) == {"x": 0.3, "z": 0.1}

        # Set params with distributions that do not match previous ones.
        # "x" was registered with a UniformDistribution in this study.
        with pytest.raises(ValueError):
            storage.set_trial_param(trial_id_2, "x", 0.5, distribution_z)
        # "y" was registered with a CategoricalDistribution in this study (via trial_id_1).
        with pytest.raises(ValueError):
            storage.set_trial_param(trial_id_2, "y", 0.5, distribution_z)
        # Choices in CategoricalDistribution should match including its order.
        with pytest.raises(ValueError):
            storage.set_trial_param(
                trial_id_2, "y", 2, CategoricalDistribution(choices=("Meguro", "Shibuya", "Ebisu"))
            )

        storage.set_trial_state(trial_id_2, TrialState.COMPLETE)
        # Cannot assign params to finished trial.
        with pytest.raises(RuntimeError):
            storage.set_trial_param(trial_id_2, "y", 2, distribution_y_1)
        # Check the previous call does not change the params.
        with pytest.raises(KeyError):
            storage.get_trial_param(trial_id_2, "y")
        # State should be checked prior to distribution compatibility.
        with pytest.raises(RuntimeError):
            storage.set_trial_param(trial_id_2, "y", 0.4, distribution_z)

        # Set params of trials in a different study.
        # A different study may register "y" with a different distribution; index 1 is "Shinsen".
        storage.set_trial_param(trial_id_3, "y", 1, distribution_y_2)
        assert storage.get_trial_param(trial_id_3, "y") == 1
        assert storage.get_trial(trial_id_3).params == {"y": "Shinsen"}
        assert storage.get_trial_params(trial_id_3) == {"y": "Shinsen"}

        # Set params of non-existent trial.
        non_existent_trial_id = max([trial_id_1, trial_id_2, trial_id_3]) + 1
        with pytest.raises(KeyError):
            storage.set_trial_param(non_existent_trial_id, "x", 0.1, distribution_x)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_set_trial_values(storage_mode: str) -> None:
    """Trial values can be set and overwritten while running, but not after completion."""
    with StorageSupplier(storage_mode) as storage:
        # Trials spread across two studies; creation order matters for the ids.
        study_id = storage.create_new_study()
        single = storage.create_new_trial(study_id)
        untouched = storage.create_new_trial(study_id)
        infinite = storage.create_new_trial(storage.create_new_study())
        multi_tuple = storage.create_new_trial(study_id)
        multi_list = storage.create_new_trial(study_id)

        assignments = [
            (single, (0.5,)),
            (infinite, (float("inf"),)),
            (multi_tuple, (0.1, 0.2, 0.3)),
            (multi_list, [0.1, 0.2, 0.3]),
        ]
        for tid, values in assignments:
            storage.set_trial_values(tid, values)

        assert storage.get_trial(single).value == 0.5
        assert storage.get_trial(untouched).value is None
        assert storage.get_trial(infinite).value == float("inf")
        assert storage.get_trial(multi_tuple).values == [0.1, 0.2, 0.3]
        assert storage.get_trial(multi_list).values == [0.1, 0.2, 0.3]

        # A running trial's values may be replaced.
        storage.set_trial_values(single, (0.2,))
        assert storage.get_trial(single).value == 0.2

        missing = max(single, untouched, infinite, multi_tuple, multi_list) + 1
        with pytest.raises(KeyError):
            storage.set_trial_values(missing, (1,))

        # Once complete, values are frozen.
        storage.set_trial_state(single, TrialState.COMPLETE)
        with pytest.raises(RuntimeError):
            storage.set_trial_values(single, (1,))


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_set_trial_intermediate_value(storage_mode: str) -> None:
    """Intermediate values are stored per step, overwritable, and frozen after completion."""
    with StorageSupplier(storage_mode) as storage:
        # Trials spread across two studies.
        study_id = storage.create_new_study()
        sparse = storage.create_new_trial(study_id)
        empty = storage.create_new_trial(study_id)
        dense = storage.create_new_trial(storage.create_new_study())

        reports = [
            (sparse, 0, 0.3),
            (sparse, 2, 0.4),
            (dense, 0, 0.1),
            (dense, 1, 0.4),
            (dense, 2, 0.5),
        ]
        for tid, step, value in reports:
            storage.set_trial_intermediate_value(tid, step, value)

        assert storage.get_trial(sparse).intermediate_values == {0: 0.3, 2: 0.4}
        assert storage.get_trial(empty).intermediate_values == {}
        assert storage.get_trial(dense).intermediate_values == {0: 0.1, 1: 0.4, 2: 0.5}

        # Reporting an already-used step replaces its value.
        storage.set_trial_intermediate_value(sparse, 0, 0.2)
        assert storage.get_trial(sparse).intermediate_values == {0: 0.2, 2: 0.4}

        missing = max(sparse, empty, dense) + 1
        with pytest.raises(KeyError):
            storage.set_trial_intermediate_value(missing, 0, 0.2)

        # Once complete, intermediate values are frozen.
        storage.set_trial_state(sparse, TrialState.COMPLETE)
        with pytest.raises(RuntimeError):
            storage.set_trial_intermediate_value(sparse, 0, 0.2)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trial_user_attrs(storage_mode: str) -> None:
    """`get_trial_user_attrs` matches the stored trials; unknown trial ids raise KeyError."""
    with StorageSupplier(storage_mode) as storage:
        _, study_to_trials = _setup_studies(storage, n_study=2, n_trial=5, seed=10)

        for trials in study_to_trials.values():
            for tid, expected in trials.items():
                assert storage.get_trial_user_attrs(tid) == expected.user_attrs

        missing = max(tid for trials in study_to_trials.values() for tid in trials) + 1
        with pytest.raises(KeyError):
            storage.get_trial_user_attrs(missing)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_set_trial_user_attr(storage_mode: str) -> None:
    """Trial user attrs are per-trial, overwritable, and frozen after completion."""
    with StorageSupplier(storage_mode) as storage:

        def write_and_verify(tid: int, name: str, attr: Any) -> None:
            storage.set_trial_user_attr(tid, name, attr)
            assert storage.get_trial(tid).user_attrs[name] == attr

        first = storage.create_new_trial(storage.create_new_study())
        for name, attr in EXAMPLE_ATTRS.items():
            write_and_verify(first, name, attr)
        assert storage.get_trial(first).user_attrs == EXAMPLE_ATTRS

        # Overwrite an existing key.
        write_and_verify(first, "dataset", "ImageNet")

        # A trial in another study sees only its own attrs.
        second = storage.create_new_trial(storage.create_new_study())
        write_and_verify(second, "baseline_score", 0.001)
        second_attrs = storage.get_trial(second).user_attrs
        assert len(second_attrs) == 1
        assert second_attrs["baseline_score"] == 0.001

        # Unknown trial ids are rejected.
        with pytest.raises(KeyError):
            storage.set_trial_user_attr(max(first, second) + 1, "key", "value")

        # Completed trials are read-only.
        storage.set_trial_state(first, TrialState.COMPLETE)
        with pytest.raises(RuntimeError):
            storage.set_trial_user_attr(first, "key", "value")


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trial_system_attrs(storage_mode: str) -> None:
    """`get_trial_system_attrs` matches the stored trials; unknown trial ids raise KeyError."""
    with StorageSupplier(storage_mode) as storage:
        _, study_to_trials = _setup_studies(storage, n_study=2, n_trial=5, seed=10)

        for trials in study_to_trials.values():
            for tid, expected in trials.items():
                assert storage.get_trial_system_attrs(tid) == expected.system_attrs

        missing = max(tid for trials in study_to_trials.values() for tid in trials) + 1
        with pytest.raises(KeyError):
            storage.get_trial_system_attrs(missing)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_set_trial_system_attr(storage_mode: str) -> None:
    """Trial system attrs are per-trial, overwritable, and frozen after completion."""
    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()
        first = storage.create_new_trial(study_id)

        def write_and_verify(tid: int, name: str, attr: Any) -> None:
            storage.set_trial_system_attr(tid, name, attr)
            assert storage.get_trial_system_attrs(tid)[name] == attr

        for name, attr in EXAMPLE_ATTRS.items():
            write_and_verify(first, name, attr)
        assert storage.get_trial(first).system_attrs == EXAMPLE_ATTRS

        # Overwrite an existing key.
        write_and_verify(first, "dataset", "ImageNet")

        # A sibling trial sees only its own attrs.
        second = storage.create_new_trial(study_id)
        write_and_verify(second, "baseline_score", 0.001)
        assert storage.get_trial(second).system_attrs == {"baseline_score": 0.001}

        # Unknown trial ids are rejected.
        with pytest.raises(KeyError):
            storage.set_trial_system_attr(max(first, second) + 1, "key", "value")

        # Completed trials are read-only.
        storage.set_trial_state(first, TrialState.COMPLETE)
        with pytest.raises(RuntimeError):
            storage.set_trial_system_attr(first, "key", "value")


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_all_study_summaries(storage_mode: str) -> None:
    """Every study summary returned by the storage matches the expected one of the same name."""
    with StorageSupplier(storage_mode) as storage:
        expected_summaries, _ = _setup_studies(storage, n_study=10, n_trial=10, seed=46)
        actual_summaries = storage.get_all_study_summaries()
        assert len(actual_summaries) == len(expected_summaries)

        compared_fields = (
            "direction",
            "datetime_start",
            "study_name",
            "n_trials",
            "user_attrs",
            "system_attrs",
        )
        for expected in expected_summaries.values():
            # Match summaries by study name; take the first hit.
            actual: Optional[StudySummary] = next(
                (s for s in actual_summaries if s.study_name == expected.study_name), None
            )
            assert actual is not None
            for field in compared_fields:
                assert getattr(actual, field) == getattr(expected, field)
            if expected.best_trial is not None:
                assert actual.best_trial is not None
                assert actual.best_trial == expected.best_trial


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trial(storage_mode: str) -> None:
    """Check that stored trials round-trip and that unknown trial IDs raise ``KeyError``."""

    with StorageSupplier(storage_mode) as storage:
        _, study_to_trials = _setup_studies(storage, n_study=2, n_trial=20, seed=47)

        all_expected = [
            expected for trials in study_to_trials.values() for expected in trials.values()
        ]
        for expected in all_expected:
            assert storage.get_trial(expected._trial_id) == expected

        # One past the largest ID handed out is guaranteed not to exist.
        largest_trial_id = max(
            trial_id for trials in study_to_trials.values() for trial_id in trials
        )
        with pytest.raises(KeyError):
            storage.get_trial(largest_trial_id + 1)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_all_trials(storage_mode: str) -> None:
    """Check that ``get_all_trials`` returns exactly the trials of the requested study."""

    with StorageSupplier(storage_mode) as storage:
        _, study_to_trials = _setup_studies(storage, n_study=2, n_trial=20, seed=48)

        for study_id, expected_trials in study_to_trials.items():
            trials = storage.get_all_trials(study_id)
            # Checking each returned trial on its own would not notice a storage that drops or
            # duplicates trials, so compare the full set of returned IDs first.
            assert sorted(t._trial_id for t in trials) == sorted(expected_trials)
            for trial in trials:
                assert trial == expected_trials[trial._trial_id]

        non_existent_study_id = max(study_to_trials.keys()) + 1
        with pytest.raises(KeyError):
            storage.get_all_trials(non_existent_study_id)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_all_trials_deepcopy_option(storage_mode: str) -> None:
    """Check the ``deepcopy`` flag of ``get_all_trials``.

    With ``deepcopy=True`` the storage must call ``copy.deepcopy`` so that callers can mutate the
    result; with ``deepcopy=False`` it must not call ``copy.deepcopy`` at all.
    """

    with StorageSupplier(storage_mode) as storage:
        study_summaries, study_to_trials = _setup_studies(storage, n_study=2, n_trial=5, seed=49)

        for study_id in study_summaries:
            with patch("copy.deepcopy", wraps=copy.deepcopy) as deepcopy_spy:
                copied_trials = storage.get_all_trials(study_id, deepcopy=True)
                assert deepcopy_spy.call_count > 0
                assert len(copied_trials) == len(study_to_trials[study_id])

            # Mutating the copied result must leave the storage's internal state untouched.
            snapshot = copy.deepcopy(copied_trials)
            copied_trials[0].params["x"] = 0.1

            with patch("copy.deepcopy", wraps=copy.deepcopy) as deepcopy_spy:
                shared_trials = storage.get_all_trials(study_id, deepcopy=False)
                assert deepcopy_spy.call_count == 0
                assert snapshot == shared_trials


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_all_trials_state_option(storage_mode: str) -> None:
    """Check that ``get_all_trials`` filters trials by its ``states`` argument."""

    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()
        storage.set_study_directions(study_id, [StudyDirection.MAXIMIZE])
        generator = random.Random(51)

        # Two COMPLETE trials followed by one PRUNED trial.
        for trial_state in (TrialState.COMPLETE, TrialState.COMPLETE, TrialState.PRUNED):
            template = _generate_trial(generator)
            template.state = trial_state
            storage.create_new_trial(study_id, template_trial=template)

        # ``states=None`` returns every trial.
        assert len(storage.get_all_trials(study_id, states=None)) == 3

        completed = storage.get_all_trials(study_id, states=(TrialState.COMPLETE,))
        assert len(completed) == 2
        assert all(t.state == TrialState.COMPLETE for t in completed)

        complete_or_pruned = (TrialState.COMPLETE, TrialState.PRUNED)
        matching = storage.get_all_trials(study_id, states=complete_or_pruned)
        assert len(matching) == 3
        assert all(t.state in complete_or_pruned for t in matching)

        # An empty filter matches nothing.
        assert len(storage.get_all_trials(study_id, states=())) == 0

        for trial_state in ALL_STATES:
            if trial_state in complete_or_pruned:
                continue
            assert len(storage.get_all_trials(study_id, states=(trial_state,))) == 0


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_n_trials(storage_mode: str) -> None:
    """Check that ``get_n_trials`` counts every trial and rejects unknown study IDs."""

    n_trial = 7
    with StorageSupplier(storage_mode) as storage:
        study_id_to_summaries, _ = _setup_studies(storage, n_study=2, n_trial=n_trial, seed=50)
        for study_id in study_id_to_summaries:
            assert storage.get_n_trials(study_id) == n_trial

        non_existent_study_id = max(study_id_to_summaries.keys()) + 1
        # The call is expected to raise, so its return value is never inspected.
        with pytest.raises(KeyError):
            storage.get_n_trials(non_existent_study_id)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_n_trials_state_option(storage_mode: str) -> None:
    """Check that ``get_n_trials`` counts only trials in the requested state."""

    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()
        storage.set_study_directions(study_id, (StudyDirection.MAXIMIZE,))
        generator = random.Random(51)

        # Two COMPLETE trials followed by one PRUNED trial.
        for trial_state in [TrialState.COMPLETE, TrialState.COMPLETE, TrialState.PRUNED]:
            template = _generate_trial(generator)
            template.state = trial_state
            storage.create_new_trial(study_id, template_trial=template)

        assert storage.get_n_trials(study_id, TrialState.COMPLETE) == 2
        assert storage.get_n_trials(study_id, TrialState.PRUNED) == 1

        # No trial was created in any other state.
        for trial_state in ALL_STATES:
            if trial_state in (TrialState.COMPLETE, TrialState.PRUNED):
                continue
            assert storage.get_n_trials(study_id, trial_state) == 0


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_best_trial(storage_mode: str) -> None:
    """Check ``get_best_trial`` error cases and that it picks the maximum for MAXIMIZE."""

    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()
        # The fresh study has no trials (and no direction set yet), so there is no best trial.
        with pytest.raises(ValueError):
            storage.get_best_trial(study_id)

        # Only one study exists, so ``study_id + 1`` is unknown.
        with pytest.raises(KeyError):
            storage.get_best_trial(study_id + 1)

        storage.set_study_directions(study_id, (StudyDirection.MAXIMIZE,))
        generator = random.Random(51)
        n_trials = 3
        for i in range(n_trials):
            template_trial = _generate_trial(generator)
            template_trial.state = TrialState.COMPLETE
            template_trial.value = float(i)
            storage.create_new_trial(study_id, template_trial=template_trial)

        # Values increase with the trial number, so the last trial is the best when maximizing.
        # State this explicitly instead of relying on the loop variable leaking out of the loop.
        assert storage.get_best_trial(study_id).number == n_trials - 1


def _setup_studies(
    storage: BaseStorage,
    n_study: int,
    n_trial: int,
    seed: int,
    direction: Optional[StudyDirection] = None,
) -> Tuple[Dict[int, StudySummary], Dict[int, Dict[int, FrozenTrial]]]:
    """Populate ``storage`` with random studies and return what they are expected to look like.

    Args:
        storage:
            Storage to populate.
        n_study:
            Number of studies to create. They are named ``test-study-name-{i}``.
        n_trial:
            Number of random trials (built by ``_generate_trial``) to create in each study.
        seed:
            Seed of the random generator, making the generated data reproducible.
        direction:
            Direction of every study. If :obj:`None`, a direction is drawn at random
            independently for each study.

    Returns:
        A tuple ``(study_id_to_summary, study_id_to_trials)``. The first maps each study ID to
        the :class:`StudySummary` the storage is expected to report (with empty user/system
        attrs). The second maps each study ID to a dict from trial ID to the stored trial.
    """
    generator = random.Random(seed)
    study_id_to_summary: Dict[int, StudySummary] = {}
    study_id_to_trials: Dict[int, Dict[int, FrozenTrial]] = {}
    for i in range(n_study):
        study_name = "test-study-name-{}".format(i)
        study_id = storage.create_new_study(study_name=study_name)
        # Use a per-study local rather than assigning to ``direction``: writing the first random
        # draw back into the parameter would make every later study reuse it.
        if direction is None:
            study_direction = generator.choice([StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE])
        else:
            study_direction = direction
        storage.set_study_directions(study_id, (study_direction,))
        best_trial = None
        trials: Dict[int, FrozenTrial] = {}
        # Earliest start time among this study's trials.
        datetime_start = None
        for j in range(n_trial):
            trial = _generate_trial(generator)
            trial.number = j
            trial._trial_id = storage.create_new_trial(study_id, trial)
            trials[trial._trial_id] = trial
            if datetime_start is None:
                datetime_start = trial.datetime_start
            else:
                datetime_start = min(datetime_start, trial.datetime_start)
            # Only completed trials with a value are candidates for the best trial.
            if trial.state != TrialState.COMPLETE or trial.value is None:
                continue
            if best_trial is None:
                best_trial = trial
            elif study_direction == StudyDirection.MINIMIZE and trial.value < best_trial.value:
                best_trial = trial
            elif study_direction == StudyDirection.MAXIMIZE and best_trial.value < trial.value:
                best_trial = trial
        study_id_to_trials[study_id] = trials
        study_id_to_summary[study_id] = StudySummary(
            study_name=study_name,
            direction=study_direction,
            best_trial=best_trial,
            user_attrs={},
            system_attrs={},
            n_trials=len(trials),
            datetime_start=datetime_start,
            study_id=study_id,
        )
    return study_id_to_summary, study_id_to_trials


def _generate_trial(generator: random.Random) -> FrozenTrial:
    """Build a random ``FrozenTrial`` template using ``generator``.

    The state, a random subset of five example parameters, random user/system attributes,
    random intermediate values and the objective value are all drawn from ``generator``, so
    the result is reproducible for a given seed. ``number`` and ``trial_id`` are dummy zeros.
    """
    # Candidate values are drawn in dict-literal order; keep this order to keep seeds stable.
    candidate_params = {
        "paramA": (generator.uniform(0, 1), UniformDistribution(0, 1)),
        "paramB": (generator.uniform(1, 2), LogUniformDistribution(1, 2)),
        "paramC": (
            generator.choice(["CatA", "CatB", "CatC"]),
            CategoricalDistribution(("CatA", "CatB", "CatC")),
        ),
        "paramD": (generator.uniform(-3, 0), UniformDistribution(-3, 0)),
        "paramE": (generator.choice([0.1, 0.2]), CategoricalDistribution((0.1, 0.2))),
    }
    candidate_attrs = {
        "attrA": "valueA",
        "attrB": 1,
        "attrC": None,
        "attrD": {"baseline_score": 0.001, "tags": ["image", "classification"]},
    }
    state = generator.choice(ALL_STATES)

    # Each parameter is kept on a coin flip, together with its distribution.
    kept = [
        (name, value, dist)
        for name, (value, dist) in candidate_params.items()
        if generator.choice([True, False])
    ]
    params = {name: value for name, value, _ in kept}
    distributions = {name: dist for name, _, dist in kept}

    # Each attribute is independently added to the user attrs and to the system attrs.
    user_attrs: Dict[str, Any] = {}
    system_attrs: Dict[str, Any] = {}
    for name, attr_value in candidate_attrs.items():
        if generator.choice([True, False]):
            user_attrs["usr_" + name] = attr_value
        if generator.choice([True, False]):
            system_attrs["sys_" + name] = attr_value

    # The coin flip is evaluated before the value is drawn, as in a plain loop.
    intermediate_values = {
        step: generator.uniform(-10, 10)
        for step in range(generator.randint(4, 10))
        if generator.choice([True, False])
    }

    objective_value = generator.uniform(-10, 10)
    return FrozenTrial(
        number=0,  # dummy
        state=state,
        value=objective_value,
        datetime_start=datetime.now(),
        datetime_complete=datetime.now() if state.is_finished() else None,
        params=params,
        distributions=distributions,
        user_attrs=user_attrs,
        system_attrs=system_attrs,
        intermediate_values=intermediate_values,
        trial_id=0,  # dummy
    )


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_best_trial_for_multi_objective_optimization(storage_mode: str) -> None:
    """Check that ``get_best_trial`` raises ``ValueError`` for a two-objective study."""

    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()
        storage.set_study_directions(study_id, (StudyDirection.MAXIMIZE, StudyDirection.MINIMIZE))

        generator = random.Random(51)
        for n in range(3):
            template = _generate_trial(generator)
            template.state = TrialState.COMPLETE
            template.values = [n, n + 1]
            storage.create_new_trial(study_id, template_trial=template)

        # Even with completed trials, two objectives yield no single best trial.
        with pytest.raises(ValueError):
            storage.get_best_trial(study_id)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trial_id_from_study_id_trial_number(storage_mode: str) -> None:
    """Check the lookup of a trial ID from a study ID and a per-study trial number."""

    with StorageSupplier(storage_mode) as storage:
        # No study exists yet.
        with pytest.raises(KeyError):
            storage.get_trial_id_from_study_id_trial_number(study_id=0, trial_number=0)

        first_study_id = storage.create_new_study()

        # The study exists but has no trial numbered 0 yet.
        with pytest.raises(KeyError):
            storage.get_trial_id_from_study_id_trial_number(first_study_id, trial_number=0)

        first_trial_id = storage.create_new_trial(first_study_id)
        found = storage.get_trial_id_from_study_id_trial_number(first_study_id, trial_number=0)
        assert found == first_trial_id

        # Trial IDs are unique across the storage, while trial numbers are only unique within a
        # study: the second study's trial 0 must resolve to that study's own trial.
        second_study_id = storage.create_new_study()
        second_trial_id = storage.create_new_trial(second_study_id)
        found = storage.get_trial_id_from_study_id_trial_number(second_study_id, trial_number=0)
        assert found == second_trial_id


@pytest.mark.parametrize("storage_mode", STORAGE_MODES_HEARTBEAT)
def test_fail_stale_trials_with_optimize(storage_mode: str) -> None:
    """Check that ``optimize`` fails the stale trials of its own study only."""

    heartbeat_interval, grace_period = 1, 2

    with StorageSupplier(
        storage_mode, heartbeat_interval=heartbeat_interval, grace_period=grace_period
    ) as storage:
        assert storage.is_heartbeat_enabled()

        optimized_study = optuna.create_study(storage=storage)
        idle_study = optuna.create_study(storage=storage)

        # Ask both trials first, then record one heartbeat for each.
        for running_trial in (optimized_study.ask(), idle_study.ask()):
            storage.record_heartbeat(running_trial._trial_id)
        # Sleep past the grace period so both heartbeats become stale.
        time.sleep(grace_period + 1)

        assert optimized_study.trials[0].state is TrialState.RUNNING
        assert idle_study.trials[0].state is TrialState.RUNNING

        # Exceptions raised in spawned threads are caught by `_TestableThread`.
        with patch("optuna.study._optimize.Thread", _TestableThread):
            optimized_study.optimize(lambda _: 1.0, n_trials=1)

        # Only the study that was optimized has its stale trial failed.
        assert optimized_study.trials[0].state is TrialState.FAIL
        assert idle_study.trials[0].state is TrialState.RUNNING


@pytest.mark.parametrize("storage_mode", STORAGE_MODES_HEARTBEAT)
def test_invalid_heartbeat_interval_and_grace_period(storage_mode: str) -> None:
    """Check that a negative heartbeat interval or grace period is rejected."""

    for invalid_kwargs in ({"heartbeat_interval": -1}, {"grace_period": -1}):
        with pytest.raises(ValueError):
            with StorageSupplier(storage_mode, **invalid_kwargs):
                pass


@pytest.mark.parametrize("storage_mode", STORAGE_MODES_HEARTBEAT)
def test_failed_trial_callback(storage_mode: str) -> None:
    """Check that a stale trial triggers the storage's ``failed_trial_callback`` exactly once."""
    heartbeat_interval, grace_period = 1, 2

    def _assert_attrs_visible(study: Study, trial: FrozenTrial) -> None:
        # The callback is expected to see the system attrs set on the study and trial below.
        assert study.system_attrs["test"] == "A"
        assert trial.system_attrs["test"] == "B"

    callback_spy = Mock(wraps=_assert_attrs_visible)

    with StorageSupplier(
        storage_mode,
        heartbeat_interval=heartbeat_interval,
        grace_period=grace_period,
        failed_trial_callback=callback_spy,
    ) as storage:
        assert storage.is_heartbeat_enabled()

        study = optuna.create_study(storage=storage)
        study.set_system_attr("test", "A")

        stale_trial = study.ask()
        stale_trial.set_system_attr("test", "B")
        storage.record_heartbeat(stale_trial._trial_id)
        # Sleep past the grace period so the heartbeat becomes stale.
        time.sleep(grace_period + 1)

        # Exceptions raised in spawned threads are caught by `_TestableThread`.
        with patch("optuna.study._optimize.Thread", _TestableThread):
            study.optimize(lambda _: 1.0, n_trials=1)
            callback_spy.assert_called_once()


@pytest.mark.parametrize(
    "storage_mode,max_retry", itertools.product(STORAGE_MODES_HEARTBEAT, [None, 0, 1])
)
def test_retry_failed_trial_callback(storage_mode: str, max_retry: Optional[int]) -> None:
    """Check whether ``RetryFailedTrialCallback`` re-enqueues a stale trial for ``max_retry``."""
    heartbeat_interval, grace_period = 1, 2
    retry_callback = RetryFailedTrialCallback(max_retry=max_retry)

    with StorageSupplier(
        storage_mode,
        heartbeat_interval=heartbeat_interval,
        grace_period=grace_period,
        failed_trial_callback=retry_callback,
    ) as storage:
        assert storage.is_heartbeat_enabled()

        study = optuna.create_study(storage=storage)

        stale_trial = study.ask()
        storage.record_heartbeat(stale_trial._trial_id)
        # Sleep past the grace period so the heartbeat becomes stale.
        time.sleep(grace_period + 1)

        # Exceptions raised in spawned threads are caught by `_TestableThread`.
        with patch("optuna.study._optimize.Thread", _TestableThread):
            study.optimize(lambda _: 1.0, n_trials=1)

        # Inspect the trial run by ``optimize``: with ``max_retry=0`` it is a fresh trial, while
        # with ``max_retry=None`` or ``max_retry=1`` it retries trial number 0.
        retried_from = RetryFailedTrialCallback.retried_trial_number(study.trials[1])
        assert retried_from == (None if max_retry == 0 else 0)


def test_fail_stale_trials() -> None:
    """Check that ``optuna.storages.fail_stale_trials`` fails a stale trial and runs the callback."""
    heartbeat_interval = 1
    grace_period = 2

    def _check_attrs(study: Study, trial: FrozenTrial) -> None:
        assert study.system_attrs["test"] == "A"
        assert trial.system_attrs["test"] == "B"

    # Wrap the callback so the test can verify that it actually ran. Otherwise the assertions
    # inside it would silently never execute if the callback were skipped.
    failed_trial_callback = Mock(wraps=_check_attrs)

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage.heartbeat_interval = heartbeat_interval
        storage.grace_period = grace_period
        storage.failed_trial_callback = failed_trial_callback
        study = optuna.create_study(storage=storage)
        study.set_system_attr("test", "A")

        trial = study.ask()
        trial.set_system_attr("test", "B")
        storage.record_heartbeat(trial._trial_id)
        # Sleep past the grace period so the heartbeat becomes stale.
        time.sleep(grace_period + 1)

        assert study.trials[0].state is TrialState.RUNNING

        optuna.storages.fail_stale_trials(study)

        assert study.trials[0].state is TrialState.FAIL
        failed_trial_callback.assert_called_once()
