Skip to content

Databricks Testing Tools

All utilities to detect databricks notebook tests and execute them on a databricks cluster.

Source code in databricks_testing_tools/testing_tools.py
class DatabricksTestingTools:
    """
    All utilities to detect databricks notebook tests and execute them on a databricks cluster.
    """

    def __init__(self) -> None:
        """
        Init the testing tools.

        Credentials are read from the DATABRICKS_HOST and DATABRICKS_TOKEN environment
        variables; jobs API version 2.1 is used so runs can be submitted with a notebook_task.
        """
        self.databricks_api_client = DatabricksAPI(host=os.environ.get('DATABRICKS_HOST'),
                                                   token=os.environ.get('DATABRICKS_TOKEN'),
                                                   api_version="2.0",
                                                   jobs_api_version="2.1")

    def list_test_notebooks(self, test_dir: str) -> Generator[NotebookObject, None, None]:
        """
        List the tests notebooks (starting with 'test_') recursively from a directory.

        Args:
            test_dir (str): the directory to search the tests from

        Returns:
            Generator[NotebookObject, None, None]: a generator of NotebookObject
        """
        objects = self.databricks_api_client.workspace.list(test_dir)
        workspace_objects = WorkspacePath.from_api_response(objects)
        yield from workspace_objects.test_notebooks
        for directory in workspace_objects.directories:
            # Recurse into sub-directories so nested test notebooks are found too.
            yield from self.list_test_notebooks(directory.path)

    def submit_job(self, notebook: NotebookObject, cluster_id: str, output: str,
                   extra_widgets: Dict[str, str] = {}) -> str:
        """
        Submit a python notebook job on a databricks cluster and store results in output directory.

        Args:
            notebook (NotebookObject): the notebook to run
            cluster_id (str): the cluster id
            output (str): the output directory where tests results will be stored.
            extra_widgets (Dict[str, str]): extra widgets to add to the databricks notebook job

        Returns:
            str: the run id of the submitted job
        """
        parameters = {"output": output}
        # BUGFIX: the previous `for key, value in extra_widgets:` iterated the dict's
        # keys and tried to unpack each key string into (key, value), which raises
        # ValueError (or silently mis-unpacks two-character keys). Merge properly.
        # Note: extra_widgets is only read, never mutated, so the {} default is safe here.
        parameters.update(extra_widgets)
        # Unique run name so concurrent runs of the same notebook can be told apart.
        response = self.databricks_api_client.jobs.submit_run(run_name=notebook.name + '_' + str(uuid.uuid4()),
                                                              existing_cluster_id=cluster_id,
                                                              notebook_task={
                                                                  "notebook_path": notebook.path,
                                                                  "base_parameters": parameters
                                                              })
        return response['run_id']

    def wait_for_job_completion(self, run_id: str, poll_wait_time: int = 10) -> bool:
        """
        Wait for a run to complete.

        Args:
            run_id (str): the id of the run
            poll_wait_time (int): The number of seconds to wait before polling databricks jobs status. Defaults to 10

        Returns:
            bool: whether the job completed successfully
        """
        while True:
            # Sleep first: the run was usually just submitted, so its state is
            # unlikely to be terminal immediately.
            time.sleep(poll_wait_time)
            response = self.databricks_api_client.jobs.get_run(run_id)
            current_state = response['state']['life_cycle_state']

            if current_state in ['TERMINATED', 'INTERNAL_ERROR', 'SKIPPED']:
                return response['state']['result_state'] == 'SUCCESS'

    def run_notebook_tests(self, cluster_id: str, output: str, test_dir: str = '', notebooks_paths: List[str] = [],
                           poll_wait_time: int = 10, nb_thread: int = 1, extra_widgets: Dict[str, str] = {}) -> None:
        """
        Search for notebook tests (name that start with 'test_') recursively from a directory and run them on
        a databricks cluster. Save the tests results to an output directory.

        Args:
            cluster_id (str): the cluster id
            output (str): the output directory where tests results will be stored.
            test_dir (str): the directory to search the tests from
            notebooks_paths (List[str]): the paths of notebooks to test if test_dir is empty
            poll_wait_time (int): The number of seconds to wait before polling databricks jobs status. Defaults to 10
            nb_thread (int): The number of threads to execute the tests
            extra_widgets (Dict[str, str]): extra widgets to add to test notebooks
        """

        def run_notebook_test(notebook) -> None:
            # Submit one notebook test run and block until it finishes.
            logger = logging.getLogger(__name__)
            logger.info(f"Starting the job for notebook tests: {notebook.path}")
            # BUGFIX: extra_widgets was accepted by this method but never forwarded
            # to submit_job, so it was silently dropped.
            run_id = self.submit_job(notebook, cluster_id, output, extra_widgets)
            state = self.wait_for_job_completion(run_id, poll_wait_time)
            logger.info(f"The job for the notebook tests {notebook.path} has ended with status {state}")

        if test_dir:
            list_nb = self.list_test_notebooks(test_dir)
        else:
            list_nb = [NotebookObject(path) for path in notebooks_paths]

        # Context manager ensures the pool's worker threads are terminated.
        with ThreadPool(nb_thread) as pool:
            pool.map(run_notebook_test, list_nb)

__init__(self) -> None special

Init the testing tools.

Source code in databricks_testing_tools/testing_tools.py
def __init__(self) -> None:
    """
    Initialize the testing tools with a Databricks API client.

    Host and token are read from the DATABRICKS_HOST and DATABRICKS_TOKEN
    environment variables.
    """
    host = os.environ.get('DATABRICKS_HOST')
    token = os.environ.get('DATABRICKS_TOKEN')
    self.databricks_api_client = DatabricksAPI(
        host=host,
        token=token,
        api_version="2.0",
        jobs_api_version="2.1",
    )

list_test_notebooks(self, test_dir: str) -> Generator[databricks_testing_tools.api_client_results.notebook_object.NotebookObject, NoneType, NoneType]

List the tests notebooks (starting with 'test_') recursively from a directory.

Parameters:

Name Type Description Default
test_dir str

the directory to search the tests from

required

Returns:

Type Description
Generator[NotebookObject, None, None]

a generator of NotebookObject

Source code in databricks_testing_tools/testing_tools.py
def list_test_notebooks(self, test_dir: str) -> Generator[NotebookObject, None, None]:
    """
    Yield the test notebooks (names starting with 'test_') found recursively under a directory.

    Args:
        test_dir (str): the directory to search the tests from

    Returns:
        Generator[NotebookObject, None, None]: a generator of NotebookObject
    """
    listing = self.databricks_api_client.workspace.list(test_dir)
    workspace = WorkspacePath.from_api_response(listing)
    # Notebooks at this level first, then descend into each sub-directory.
    yield from workspace.test_notebooks
    for sub_dir in workspace.directories:
        yield from self.list_test_notebooks(sub_dir.path)

run_notebook_tests(self, cluster_id: str, output: str, test_dir: str = '', notebooks_paths: List[str] = [], poll_wait_time: int = 10, nb_thread: int = 1, extra_widgets: Dict[str, str] = {}) -> None

Search for notebook tests (name that start with 'test_') recursively from a directory and run them on a databricks cluster. Save the tests results to an output directory.

Parameters:

Name Type Description Default
cluster_id str

the cluster id

required
output str

the output directory where tests results will be stored.

required
test_dir str

the directory to search the tests from

''
notebooks_paths List[str]

the paths of notebooks to test if test_dir is empty

[]
poll_wait_time int

The number of seconds to wait before polling databricks jobs status. Defaults to 10

10
nb_thread int

The number of threads to execute the tests

1
extra_widgets Dict[str, str]

extra widgets to add to test notebooks

{}
Source code in databricks_testing_tools/testing_tools.py
def run_notebook_tests(self, cluster_id: str, output: str, test_dir: str = '', notebooks_paths: List[str] = [],
                       poll_wait_time: int = 10, nb_thread: int = 1, extra_widgets: Dict[str, str] = {}) -> None:
    """
    Search for notebook tests (name that start with 'test_') recursively from a directory and run them on
    a databricks cluster. Save the tests results to an output directory.

    Args:
        cluster_id (str): the cluster id
        output (str): the output directory where tests results will be stored.
        test_dir (str): the directory to search the tests from
        notebooks_paths (List[str]): the paths of notebooks to test if test_dir is empty
        poll_wait_time (int): The number of seconds to wait before polling databricks jobs status. Defaults to 10
        nb_thread (int): The number of threads to execute the tests
        extra_widgets (Dict[str, str]): extra widgets to add to test notebooks
    """

    def run_notebook_test(notebook) -> None:
        # Submit one notebook test run and block until it finishes.
        logger = logging.getLogger(__name__)
        logger.info(f"Starting the job for notebook tests: {notebook.path}")
        # BUGFIX: extra_widgets was accepted by this method but never forwarded
        # to submit_job, so it was silently dropped.
        run_id = self.submit_job(notebook, cluster_id, output, extra_widgets)
        state = self.wait_for_job_completion(run_id, poll_wait_time)
        logger.info(f"The job for the notebook tests {notebook.path} has ended with status {state}")

    if test_dir:
        list_nb = self.list_test_notebooks(test_dir)
    else:
        list_nb = [NotebookObject(path) for path in notebooks_paths]

    # Context manager ensures the pool's worker threads are terminated.
    with ThreadPool(nb_thread) as pool:
        pool.map(run_notebook_test, list_nb)

submit_job(self, notebook: NotebookObject, cluster_id: str, output: str, extra_widgets: Dict[str, str] = {}) -> str

Submit a python notebook job on a databricks cluster and store results in output directory.

Parameters:

Name Type Description Default
notebook NotebookObject

the notebook to run

required
cluster_id str

the cluster id

required
output str

the output directory where tests results will be stored.

required
extra_widgets Dict[str, str]

extra widgets to add to the databricks notebook job

{}

Returns:

Type Description
str

the run id of the submitted job

Source code in databricks_testing_tools/testing_tools.py
def submit_job(self, notebook: NotebookObject, cluster_id: str, output: str,
               extra_widgets: Dict[str, str] = {}) -> str:
    """
    Submit a python notebook job on a databricks cluster and store results in output directory.

    Args:
        notebook (NotebookObject): the notebook to run
        cluster_id (str): the cluster id
        output (str): the output directory where tests results will be stored.
        extra_widgets (Dict[str, str]): extra widgets to add to the databricks notebook job

    Returns:
        str: the run id of the submitted job
    """
    parameters = {"output": output}
    # BUGFIX: the previous `for key, value in extra_widgets:` iterated the dict's
    # keys and tried to unpack each key string into (key, value), which raises
    # ValueError (or silently mis-unpacks two-character keys). Merge properly.
    # Note: extra_widgets is only read, never mutated, so the {} default is safe here.
    parameters.update(extra_widgets)
    # Unique run name so concurrent runs of the same notebook can be told apart.
    response = self.databricks_api_client.jobs.submit_run(run_name=notebook.name + '_' + str(uuid.uuid4()),
                                                          existing_cluster_id=cluster_id,
                                                          notebook_task={
                                                              "notebook_path": notebook.path,
                                                              "base_parameters": parameters
                                                          })
    return response['run_id']

wait_for_job_completion(self, run_id: str, poll_wait_time: int = 10) -> bool

Wait for a run to complete.

Parameters:

Name Type Description Default
run_id str

the id of the run

required
poll_wait_time int

The number of seconds to wait before polling databricks jobs status. Defaults to 10

10

Returns:

Type Description
bool

whether the job completed successfully

Source code in databricks_testing_tools/testing_tools.py
def wait_for_job_completion(self, run_id: str, poll_wait_time: int = 10) -> bool:
    """
    Block until the given run reaches a terminal state.

    Args:
        run_id (str): the id of the run
        poll_wait_time (int): The number of seconds to wait before polling databricks jobs status. Defaults to 10

    Returns:
        bool: if the job complete successfully or not
    """
    terminal_states = ('TERMINATED', 'INTERNAL_ERROR', 'SKIPPED')
    run_info = None
    done = False
    while not done:
        # Sleep first, then poll, so a freshly submitted run has time to progress.
        time.sleep(poll_wait_time)
        run_info = self.databricks_api_client.jobs.get_run(run_id)
        done = run_info['state']['life_cycle_state'] in terminal_states
    return run_info['state']['result_state'] == 'SUCCESS'