# Source code for cookiecutter.main

"""
Main entry point for the `cookiecutter` command.

The code in this module is also a good example of how to use Cookiecutter as a
library rather than a script.
"""

from __future__ import annotations

import logging
import os
import sys
from copy import copy
from pathlib import Path
from typing import Any

from cookiecutter.config import get_user_config
from cookiecutter.exceptions import InvalidModeException
from cookiecutter.generate import generate_context, generate_files
from cookiecutter.hooks import run_pre_prompt_hook
from cookiecutter.prompt import choose_nested_template, prompt_for_config
from cookiecutter.replay import dump, load
from cookiecutter.repository import determine_repo_dir
from cookiecutter.utils import rmtree

# Module-level logger named after this module; used for the debug messages
# emitted while loading and building the template context.
logger = logging.getLogger(__name__)


def cookiecutter(
    template: str,
    checkout: str | None = None,
    no_input: bool = False,
    extra_context: dict[str, Any] | None = None,
    replay: bool | str | None = None,
    overwrite_if_exists: bool = False,
    output_dir: str = '.',
    config_file: str | None = None,
    default_config: bool = False,
    password: str | None = None,
    directory: str | None = None,
    skip_if_file_exists: bool = False,
    accept_hooks: bool = True,
    keep_project_on_failure: bool = False,
) -> str:
    """
    Run Cookiecutter just as if using it from the command line.

    :param template: A directory containing a project template directory,
        or a URL to a git repository.
    :param checkout: The branch, tag or commit ID to checkout after clone.
    :param no_input: Do not prompt for user input.
        Use default values for template parameters taken from
        `cookiecutter.json`, user config and `extra_context`. Also forwarded
        to the repository lookup (per the original docs: force a refresh of
        cached resources).
    :param extra_context: A dictionary of context that overrides default
        and user configuration. Cannot be combined with ``replay``.
    :param replay: Do not prompt for input, instead read from saved json.
        If ``True``, read this template's replay file from the ``replay_dir``
        of the user config; if a string, treat it as the path to a replay
        file. Variables the replay file does not contain are still prompted
        for.
    :param overwrite_if_exists: Overwrite the contents of the output directory
        if it exists.
    :param output_dir: Where to output the generated project dir into.
    :param config_file: User configuration file path.
    :param default_config: Use default values rather than a config file.
    :param password: The password to use when extracting the repository.
    :param directory: Relative path to a cookiecutter template in a repository.
    :param skip_if_file_exists: Skip the files in the corresponding directories
        if they already exist.
    :param accept_hooks: Accept pre and post hooks if set to `True`.
    :param keep_project_on_failure: If `True` keep generated project directory
        even when generation fails.
    :raises InvalidModeException: If ``replay`` is combined with ``no_input``
        or ``extra_context``.
    :return: The value returned by ``generate_files`` (presumably the path of
        the generated project). If the template context defines ``template``
        or ``templates``, a nested template is chosen and this function is
        called recursively for it, returning that call's result instead.

    Temporary directories (the copy made by the pre-prompt hook, and a base
    repository directory flagged for cleanup by ``determine_repo_dir``) are
    removed before returning, including on early return and on failure.
    """
    if replay and ((no_input is not False) or (extra_context is not None)):
        err_msg = (
            "You can not use both replay and no_input or extra_context "
            "at the same time."
        )
        raise InvalidModeException(err_msg)

    config_dict = get_user_config(
        config_file=config_file,
        default_config=default_config,
    )

    base_repo_dir, cleanup_base_repo_dir = determine_repo_dir(
        template=template,
        abbreviations=config_dict['abbreviations'],
        clone_to_dir=config_dict['cookiecutters_dir'],
        checkout=checkout,
        no_input=no_input,
        password=password,
        directory=directory,
    )

    # Until the pre-prompt hook runs, the working copy is the base repo dir
    # itself, so there is no separate temporary copy to remove yet.
    repo_dir = base_repo_dir
    cleanup = False
    try:
        # Run pre_prompt hook
        if accept_hooks:
            repo_dir = str(run_pre_prompt_hook(base_repo_dir))
        # Always remove temporary dir if it was created
        cleanup = repo_dir != base_repo_dir

        import_patch = _patch_import_path_for_repo(repo_dir)
        template_name = os.path.basename(os.path.abspath(repo_dir))

        if replay:
            with import_patch:
                if isinstance(replay, bool):
                    context_from_replayfile = load(
                        config_dict['replay_dir'], template_name
                    )
                else:
                    # A string replay is a path to a replay file; its stem
                    # becomes the template name used for loading (and later
                    # for dumping).
                    path, template_name = os.path.split(
                        os.path.splitext(replay)[0]
                    )
                    context_from_replayfile = load(path, template_name)

        context_file = os.path.join(repo_dir, 'cookiecutter.json')
        logger.debug('context_file is %s', context_file)

        if replay:
            context = generate_context(
                context_file=context_file,
                default_context=config_dict['default_context'],
                extra_context=None,
            )
            logger.debug('replayfile context: %s', context_from_replayfile)
            # Only prompt for template variables the replay file lacks.
            items_for_prompting = {
                k: v
                for k, v in context['cookiecutter'].items()
                if k not in context_from_replayfile['cookiecutter']
            }
            context_for_prompting = {'cookiecutter': items_for_prompting}
            context = context_from_replayfile
            logger.debug('prompting context: %s', context_for_prompting)
        else:
            context = generate_context(
                context_file=context_file,
                default_context=config_dict['default_context'],
                extra_context=extra_context,
            )
            context_for_prompting = context

        # preserve the original cookiecutter options
        context['_cookiecutter'] = {
            k: v
            for k, v in context['cookiecutter'].items()
            if not k.startswith("_")
        }

        # prompt the user to manually configure at the command line.
        # except when 'no-input' flag is set
        with import_patch:
            if {"template", "templates"} & set(context["cookiecutter"].keys()):
                nested_template = choose_nested_template(
                    context, repo_dir, no_input
                )
                # The outer temporary dirs are cleaned up by the ``finally``
                # below only after the nested run has completed.
                return cookiecutter(
                    template=nested_template,
                    checkout=checkout,
                    no_input=no_input,
                    extra_context=extra_context,
                    replay=replay,
                    overwrite_if_exists=overwrite_if_exists,
                    output_dir=output_dir,
                    config_file=config_file,
                    default_config=default_config,
                    password=password,
                    directory=directory,
                    skip_if_file_exists=skip_if_file_exists,
                    accept_hooks=accept_hooks,
                    keep_project_on_failure=keep_project_on_failure,
                )
            if context_for_prompting['cookiecutter']:
                context['cookiecutter'].update(
                    prompt_for_config(context_for_prompting, no_input)
                )

        logger.debug('context is %s', context)

        # include template dir or url in the context dict
        context['cookiecutter']['_template'] = template

        # include output+dir in the context dict
        context['cookiecutter']['_output_dir'] = os.path.abspath(output_dir)

        # include repo dir or url in the context dict
        context['cookiecutter']['_repo_dir'] = f"{repo_dir}"

        # include checkout details in the context dict
        context['cookiecutter']['_checkout'] = checkout

        dump(config_dict['replay_dir'], template_name, context)

        # Create project from local context and project template.
        with import_patch:
            result = generate_files(
                repo_dir=repo_dir,
                context=context,
                overwrite_if_exists=overwrite_if_exists,
                skip_if_file_exists=skip_if_file_exists,
                output_dir=output_dir,
                accept_hooks=accept_hooks,
                keep_project_on_failure=keep_project_on_failure,
            )

        return result
    finally:
        # Cleanup (if required) -- runs on normal return, on the nested
        # template early return, and when any step above raises.
        if cleanup:
            rmtree(repo_dir)
        if cleanup_base_repo_dir:
            rmtree(base_repo_dir)
class _patch_import_path_for_repo: # noqa: N801 def __init__(self, repo_dir: Path | str) -> None: self._repo_dir = f"{repo_dir}" if isinstance(repo_dir, Path) else repo_dir def __enter__(self) -> None: self._path = copy(sys.path) sys.path.append(self._repo_dir) def __exit__(self, type, value, traceback): # type: ignore[no-untyped-def] sys.path = self._path