from collections import defaultdict
import logging
import os

from parallels.core import messages
from parallels.core.utils.common import mkdir_p
from parallels.core.utils.json_utils import read_json, write_json

# Module-level logger; used below to report failures while loading the saved migration state.
logger = logging.getLogger(__name__)


class MigrationState(object):
    """Global migration state that is kept between migrator runs.

    The state is a collection of variable dictionaries, each identified by an
    (action, subscription) pair in which either part may be None. It is
    persisted as JSON to 'state.json' inside the state directory.

    Storing files alongside the state may be implemented here in the future.
    """

    def __init__(self, state_dir):
        """
        :param state_dir: directory holding the state file; it is created on save() if missing
        :type state_dir: str | unicode
        """
        # A missing key yields a fresh empty dict, so get_vars() always hands
        # back a mutable dict the caller can fill in place.
        self._state = defaultdict(dict)
        self._state_dir = state_dir

    def get_vars(self, action=None, subscription=None):
        """Return the dictionary of state variables for an action, a subscription, or both.

        The stored dict itself is returned (not a copy), so changes made by the
        caller become part of the state and are written out by the next save().
        Asking for a key that is not present yet registers an empty dict for it.

        :type action: str | unicode | None
        :type subscription: str | unicode | None
        :rtype: dict
        """
        return self._state[(action, subscription)]

    def save(self):
        """Write the whole state to the state file, creating the state directory if needed.

        Each entry is stored as ``[{'action': ..., 'subscription': ...}, variables]``.
        """
        # items() rather than iteritems(): same result on Python 2, and also valid on Python 3.
        items = [
            [dict(action=action, subscription=subscription), variables]
            for (action, subscription), variables in self._state.items()
        ]
        if not os.path.exists(self._state_dir):
            mkdir_p(self._state_dir)
        write_json(self._get_state_file_path(), items)

    def load(self):
        """Merge state previously written by save() into the current state.

        Loaded entries replace in-memory entries with the same (action, subscription) key.
        This is best effort: if the state file cannot be read or parsed, the error
        is logged and the in-memory state is left unchanged (nothing is partially merged).
        A missing state file (nothing saved yet) is not treated as an error.
        """
        loaded = {}
        try:
            state_file = self._get_state_file_path()
            if not os.path.exists(state_file):
                # No state has been saved yet (save() creates the directory
                # and file on demand), so there is nothing to load.
                return
            items = read_json(state_file)
            if not items:
                return
            for key_dict, variables in items:
                loaded[(key_dict.get('action'), key_dict.get('subscription'))] = variables
        except Exception:
            logger.debug(messages.LOG_EXCEPTION, exc_info=True)
            logger.warning(messages.FAILED_LOAD_MIGRATION_STATE)
            return
        # Commit only once every entry has been parsed, so a malformed file
        # cannot leave the state half-loaded.
        self._state.update(loaded)

    def _get_state_file_path(self):
        """Return the path of the JSON file the state is persisted to."""
        return os.path.join(self._state_dir, 'state.json')
