import pytest

from ..serial_executor import SerialExecutor  # Replace 'your_module' with the actual module name.


def sample_task(x):
    """Return *x* doubled; a trivial task used to check submitted results."""
    doubled = x * 2
    return doubled


def sample_failing_task():
    """Always raise ``ValueError`` so tests can check exception propagation."""
    message = "Intentional Error"
    raise ValueError(message)


def test_serial_executor_synchronous_execution():
    """Submit tasks without a ``with`` block and check results plus progress reports."""
    n_tasks = 3
    progress = []

    def report_function(completed, _):
        # Only the running completion count is recorded; the second argument is ignored.
        progress.append(completed)

    executor = SerialExecutor(total=n_tasks, report_function=report_function)
    submitted = []
    for value in range(n_tasks):
        submitted.append(executor.submit(sample_task, value))

    for expected_input, pending in zip(range(n_tasks), submitted):
        assert pending.result() == expected_input * 2
    # One report per task, counting up from 1.
    assert progress == list(range(1, n_tasks + 1))


def test_serial_executor_context_management():
    """Use the executor as a context manager and check results plus progress reports."""
    task_count = 2
    reported = []

    def report_function(completed, total):
        reported.append(completed)

    with SerialExecutor(total=task_count, report_function=report_function) as executor:
        futures = []
        for arg in range(task_count):
            futures.append(executor.submit(sample_task, arg))

        for arg, fut in zip(range(task_count), futures):
            assert fut.result() == arg * 2
    # Checked after the ``with`` block has exited.
    assert reported == list(range(1, task_count + 1))


def test_serial_executor_handles_exceptions():
    """Check that an exception raised inside a task resurfaces from ``Future.result()``."""
    with SerialExecutor(total=1) as executor:
        failing = executor.submit(sample_failing_task)
        with pytest.raises(ValueError, match="Intentional Error"):
            failing.result()


def test_serial_executor_no_report_function():
    """Check that the executor works when no ``report_function`` callback is given."""
    count = 2

    with SerialExecutor(total=count) as executor:
        pending = [executor.submit(sample_task, n) for n in range(count)]
        for n, fut in zip(range(count), pending):
            assert fut.result() == n * 2