Python six.viewkeys() Examples

The following are 30 code examples of six.viewkeys(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module six, or try the search function.
Example #1
Source File: linthompsamp.py (project: striatum, license: BSD 2-Clause "Simplified" License, 7 votes)
def _linthompsamp_score(self, context):
        """Score each action in `context` with linear Thompson Sampling.

        A parameter vector ``mu_tilde`` is drawn from
        ``N(mu_hat, v**2 * inv(B))`` using the stored model. Every action's
        context vector is then scored against ``mu_hat`` (the estimated
        reward) and against ``mu_tilde`` (the sampled score).

        Returns a tuple ``(estimated_reward, uncertainty, score)`` of dicts
        keyed by action id, where ``uncertainty = score - estimated_reward``.
        """
        action_ids = list(six.viewkeys(context))
        features = np.asarray([context[aid] for aid in action_ids])

        model = self._model_storage.get_model()
        b_matrix = model['B']
        mu_hat = model['mu_hat']

        # Scale of the sampling covariance:
        # v = R * sqrt(24 / epsilon * d * ln(1 / delta)).
        v = self.R * np.sqrt(24 / self.epsilon
                             * self.context_dimension
                             * np.log(1 / self.delta))
        covariance = v**2 * np.linalg.inv(b_matrix)
        mu_tilde = self.random_state.multivariate_normal(
            mu_hat.flat, covariance)[..., np.newaxis]

        expected = features.dot(mu_hat)
        sampled = features.dot(mu_tilde)

        estimated_reward_dict = {
            aid: float(r) for aid, r in zip(action_ids, expected)}
        score_dict = {aid: float(s) for aid, s in zip(action_ids, sampled)}
        uncertainty_dict = {
            aid: float(s - r)
            for aid, r, s in zip(action_ids, expected, sampled)}
        return estimated_reward_dict, uncertainty_dict, score_dict
Example #2
Source File: clusters_diff.py (project: biggraphite, license: Apache License 2.0, 6 votes)
def measure_dissymmetry(self, other):
        """Build a Dissymmetry comparing this series with `other`.

        The timestamps of both sides are unioned. For each timestamp, the two
        values (``None`` where a side has no value) are passed to
        ``_measure_relative_gap``. A falsy `other` is treated as an empty
        series. Returns ``None`` when neither side has any timestamp.
        """
        theirs = other.ts_to_val if other else {}
        timestamps = six.viewkeys(self.ts_to_val) | six.viewkeys(theirs)
        if not timestamps:
            return None

        gaps = []
        for ts in timestamps:
            mine_val = self.ts_to_val.get(ts)
            their_val = theirs.get(ts)
            gaps.append(self._measure_relative_gap(mine_val, their_val))
        return Dissymmetry(self.name, gaps)
Example #3
Source File: b301_b302_b305.py (project: flake8-bugbear, license: MIT License, 6 votes)
def this_is_okay():
    """Lint fixture: calls that presumably must NOT be flagged.

    Judging by the file name (b301_b302_b305.py) and the function name, this
    lists helper-style calls (bare imported names, ``six.*``,
    ``future.utils.*``, ``builtins.next``) that the checks are expected to
    accept, as opposed to methods called directly on a dict. The body is not
    meant to run: the helper names are not defined here, and ``next`` on a
    plain dict would raise TypeError.
    """
    d = {}
    iterkeys(d)
    six.iterkeys(d)
    six.itervalues(d)
    six.iteritems(d)
    six.iterlists(d)
    six.viewkeys(d)
    six.viewvalues(d)
    six.viewlists(d)
    itervalues(d)
    future.utils.iterkeys(d)
    future.utils.itervalues(d)
    future.utils.iteritems(d)
    future.utils.iterlists(d)
    future.utils.viewkeys(d)
    future.utils.viewvalues(d)
    future.utils.viewlists(d)
    six.next(d)
    builtins.next(d)
Example #4
Source File: predicates.py (project: catalyst, license: Apache License 2.0, 6 votes)
def assert_dict_equal(result, expected, path=(), msg='', **kwargs):
    """Assert two mappings have the same keys and equal values.

    The key sets are compared first with ``_check_sets``. Then each pair of
    values is compared with ``assert_equal``, extending `path` with
    ``[key]``. Failures from individual values are collected and raised
    together as a single AssertionError, one message per line.
    """
    keys_path = path + ('.%s()' % ('viewkeys' if PY2 else 'keys'),)
    _check_sets(viewkeys(result), viewkeys(expected), msg, keys_path, 'key')

    errors = []
    for key, (got, want) in iteritems(dzip_exact(result, expected)):
        try:
            assert_equal(got, want, path=path + ('[%r]' % (key,),),
                         msg=msg, **kwargs)
        except AssertionError as exc:
            errors.append(str(exc))

    if errors:
        raise AssertionError('\n'.join(errors))
Example #5
Source File: test_assets.py (project: catalyst, license: Apache License 2.0, 6 votes)
def test_blocked_lookup_symbol_query(self):
        # Ask for more sids than SQLITE_MAX_VARIABLE_NUMBER. The asset finder
        # must therefore split the lookup into chunks on the client side.
        as_of = pd.Timestamp('2013-01-01', tz='UTC')
        sids = range(SQLITE_MAX_VARIABLE_NUMBER + 10)

        def record(sid):
            # One equity row per sid, all alive only on `as_of`.
            return {
                'sid': sid,
                'symbol': 'TEST.%d' % sid,
                'start_date': as_of.value,
                'end_date': as_of.value,
                'exchange': uuid.uuid4().hex,
            }

        frame = pd.DataFrame.from_records([record(sid) for sid in sids])
        self.write_assets(equities=frame)

        assets = self.asset_finder.retrieve_equities(sids)
        assert_equal(viewkeys(assets), set(sids))
Example #6
Source File: eval_coco_format.py (project: MAX-Image-Segmenter, license: Apache License 2.0, 6 votes)
def _is_thing_array(categories_json, ignored_label):
  """Build an array whose entry ``i`` flags category ``i`` as a "thing".

  Args:
    categories_json: iterable of category dicts, each read for its ``'id'``
      and ``'isthing'`` keys.
    ignored_label: a category id that may be missing without a warning.

  Returns:
    A numpy float array of length ``max(id) + 1`` holding 1.0 for "thing"
    categories and 0.0 for "stuff" categories or ids with no JSON entry.
  """
  is_thing_by_id = {}
  for category in categories_json:
    is_thing_by_id[category['id']] = bool(category['isthing'])

  # The ids are expected to be consecutive from 0. Metrics can usually cope
  # when they are not, so gaps only produce a warning.
  max_id = max(six.iterkeys(is_thing_by_id))
  if len(is_thing_by_id) != max_id + 1:
    missing = set(six.moves.range(max_id + 1)).difference(
        six.viewkeys(is_thing_by_id))
    if missing != {ignored_label}:
      logging.warning(
          'Nonconsecutive category ids or no category JSON specified for ids: '
          '%s', missing)

  is_thing_array = np.zeros(max_id + 1)
  for category_id, flag in six.iteritems(is_thing_by_id):
    is_thing_array[category_id] = flag
  return is_thing_array
Example #7
Source File: api.py (project: magnum, license: Apache License 2.0, 6 votes)
def _do_update_cluster_template(self, cluster_template_id, values):
        """Apply `values` to a ClusterTemplate row and return the model.

        The row is loaded with a row lock inside a transaction. If
        ``_is_cluster_template_referenced`` reports the template as in use,
        the update is accepted only when it is a publishing change (per
        ``_is_publishing_cluster_template``) or touches nothing but ``name``.

        :raises ClusterTemplateNotFound: if no row matches the id.
        :raises ClusterTemplateReferenced: if the template is referenced and
            the update is neither a publishing change nor a rename.
        """
        session = get_session()
        with session.begin():
            query = model_query(models.ClusterTemplate, session=session)
            query = add_identity_filter(query, cluster_template_id)
            try:
                # with_for_update() is SQLAlchemy's replacement for the
                # deprecated with_lockmode('update'); both lock the row
                # with SELECT ... FOR UPDATE.
                ref = query.with_for_update().one()
            except NoResultFound:
                raise exception.ClusterTemplateNotFound(
                    clustertemplate=cluster_template_id)

            if self._is_cluster_template_referenced(session, ref['uuid']):
                # NOTE(flwang): We only allow to update ClusterTemplate to be
                # public, hidden and rename
                if (not self._is_publishing_cluster_template(values) and
                        list(six.viewkeys(values)) != ["name"]):
                    raise exception.ClusterTemplateReferenced(
                        clustertemplate=cluster_template_id)

            ref.update(values)
        return ref
Example #8
Source File: tags.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 6 votes)
def frame(cls, F, *tags):
        '''Iterate through each field containing the specified `tags` within the frame belonging to the function `F`.

        Yields ``(offset, (field, type, comment))`` tuples. When no `tags` are
        given, or the raw comment is itself one of `tags`, the member is
        yielded unchanged. Otherwise the comment is decoded, reduced to the
        requested tags and re-encoded; members with none of them are skipped.
        '''
        global read, internal
        wanted = set(tags)

        for ofs, item in read.frame(F):
            field, ftype, comment = item

            # Whole comment matches (e.g. None), or nothing was requested.
            if not tags or comment in wanted:
                yield ofs, item
                continue

            # Keep only the requested tags from the decoded comment.
            decoded = internal.comment.decode(comment)
            selected = {key: decoded[key] for key in six.viewkeys(decoded) & wanted}
            if selected:
                yield ofs, (field, ftype, internal.comment.encode(selected))
        return

    ## query the entire database for the specified tags 
Example #9
Source File: __init__.py (project: treadmill, license: Apache License 2.0, 6 votes)
def _gc(self):
        """Remove disconnected websocket handlers.

        For each directory in ``self.handlers``, only the entries whose
        handler reports ``active(sub_id=...)`` are kept. A directory left with
        no entries is dropped from ``self.handlers``. If that directory is not
        listed in ``self.watch_dirs``, it is also removed from the watcher.
        """
        # Snapshot the keys: entries are popped or reassigned while looping.
        directories = list(six.viewkeys(self.handlers))
        for directory in directories:
            survivors = []
            for pattern, handler, impl, sub_id in self.handlers[directory]:
                if handler.active(sub_id=sub_id):
                    survivors.append((pattern, handler, impl, sub_id))

            _LOGGER.debug('Number of active handlers for %s: %s',
                          directory, len(survivors))

            if survivors:
                self.handlers[directory] = survivors
                continue

            _LOGGER.debug('No active handlers for %s', directory)
            self.handlers.pop(directory, None)
            if directory not in self.watch_dirs:
                # Watch is not permanent, remove dir from watcher.
                self.watcher.remove_dir(directory)
Example #10
Source File: containers_test.py (project: dm_control, license: Apache License 2.0, 6 votes)
def test_query_tag_intersection(self, query, expected_keys):
    """`tagged(*query)` returns exactly the task names in `expected_keys`."""
    tasks = containers.TaggedTasks()

    # Three tasks with overlapping tags; they are only reached via `tasks`.
    # pylint: disable=unused-variable
    @tasks.add('a', 'b')
    def f1():
      pass

    @tasks.add('a', 'b', 'c')
    def f2():
      pass

    @tasks.add('a', 'c', 'd')
    def f3():
      pass
    # pylint: enable=unused-variable

    matched_names = frozenset(six.viewkeys(tasks.tagged(*query)))
    self.assertSetEqual(matched_names, expected_keys)
Example #11
Source File: containers.py (project: dm_control, license: Apache License 2.0, 6 votes)
def tagged(self, *tags):
    """Returns a (possibly empty) dict of functions matching all the given tags.

    Args:
      *tags: Strings specifying tags to query by.

    Returns:
      A new dict of `{name: function}` holding every function tagged with all
      of `tags`. It is empty when no tags are given or any tag is unknown.
    """
    wanted = set(tags)
    if not wanted or not wanted.issubset(six.viewkeys(self._tags)):
      return {}
    # Start from one tag's names and intersect with each remaining tag's.
    names = six.viewkeys(self._tags[wanted.pop()])
    while wanted:
      names &= six.viewkeys(self._tags[wanted.pop()])
    return dict((name, self._tasks[name]) for name in names)
Example #12
Source File: eval_coco_format.py (project: models, license: Apache License 2.0, 6 votes)
def _is_thing_array(categories_json, ignored_label):
  """Build an array whose entry ``i`` flags category ``i`` as a "thing".

  Args:
    categories_json: iterable of category dicts, each read for its ``'id'``
      and ``'isthing'`` keys.
    ignored_label: a category id that may be missing without a warning.

  Returns:
    A numpy float array of length ``max(id) + 1`` holding 1.0 for "thing"
    categories and 0.0 for "stuff" categories or ids with no JSON entry.
  """
  is_thing_by_id = {}
  for category in categories_json:
    is_thing_by_id[category['id']] = bool(category['isthing'])

  # The ids are expected to be consecutive from 0. Metrics can usually cope
  # when they are not, so gaps only produce a warning.
  max_id = max(six.iterkeys(is_thing_by_id))
  if len(is_thing_by_id) != max_id + 1:
    missing = set(six.moves.range(max_id + 1)).difference(
        six.viewkeys(is_thing_by_id))
    if missing != {ignored_label}:
      logging.warning(
          'Nonconsecutive category ids or no category JSON specified for ids: '
          '%s', missing)

  is_thing_array = np.zeros(max_id + 1)
  for category_id, flag in six.iteritems(is_thing_by_id):
    is_thing_array[category_id] = flag
  return is_thing_array
Example #13
Source File: _utils.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def transform(translate, *names):
    '''This applies the callable `translate` to any function arguments that match `names` in the decorated function.

    Positional arguments bound to a parameter in `names` are translated, and
    so are extra positionals when the wildcard (``*args``) parameter's name is
    in `names`. Keyword arguments whose names are in `names` are translated
    too.
    '''
    selected = set(names)

    def wrapper(F, *rargs, **rkwds):
        f = wrap.extract(F)
        argnames, _defaults, (wildname, _kwdname) = wrap.arguments(f)

        # positionals that line up with a named parameter
        args = tuple(
            translate(value) if argname in selected else value
            for value, argname in zip(rargs, argnames))

        # any remaining positionals are collected by the wildcard parameter
        args += tuple(
            translate(value) if wildname in selected else value
            for value in rargs[len(args):])

        # keyword arguments, on a copy so the caller's dict is untouched
        kwds = dict(rkwds)
        for argname in six.viewkeys(rkwds) & selected:
            kwds[argname] = translate(kwds[argname])
        return F(*args, **kwds)

    # decorator that wraps the function `F` with `wrapper`.
    def decorate(F):
        return wrap(F, wrapper)
    return decorate
Example #14
Source File: data_generator.py (project: feagen, license: BSD 2-Clause "Simplified" License, 5 votes)
def __init__(self, handlers):
        """Store `handlers` once its keys match ``self._handler_set`` exactly.

        :raises ValueError: listing the unexpected ("redundant") and the
            missing ("lacked") handler names when the key sets differ.
        """
        given = set(six.viewkeys(handlers))
        expected = self._handler_set
        if given == expected:
            self._handlers = handlers
            return
        raise ValueError('Handler set mismatch. {} redundant and {} lacked.'
                         .format(given - expected, expected - given))
Example #15
Source File: xdict.py (project: python-esppy, license: Apache License 2.0, 5 votes)
def viewflatkeys(self):
        ''' Return a keys view over ``self.flattened()`` '''
        flat = self.flattened()
        return six.viewkeys(flat)
Example #16
Source File: test_utils.py (project: models, license: Apache License 2.0, 5 votes)
def read_segmentation_with_rgb_color_map(image_testdata_path,
                                         rgb_to_semantic_label,
                                         output_dtype=None):
  """Reads a test segmentation as an image and a map from colors to labels.

  Args:
    image_testdata_path: Image path relative to panoptic_segmentation/testdata
      as a string.
    rgb_to_semantic_label: Mapping from RGB colors to integer labels as a
      dictionary.
    output_dtype: Type of the output labels. If None, defaults to the type of
      the provided color map.

  Returns:
    A 2D numpy array of labels.

  Raises:
    AssertionError: If the image is not shaped ``(H, W, 3)``.
    ValueError: On an incomplete `rgb_to_semantic_label`.
  """
  rgb_image = read_test_image(image_testdata_path, mode='RGB')
  if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
    # Wrap the shape in a 1-tuple: `%` would otherwise treat a multi-dim
    # shape tuple as several format arguments and raise TypeError.
    raise AssertionError(
        'Expected RGB image, actual shape is %s' % (rgb_image.shape,))

  # Every distinct pixel color must have a label in the map.
  num_pixels = rgb_image.shape[0] * rgb_image.shape[1]
  unique_colors = np.unique(np.reshape(rgb_image, [num_pixels, 3]), axis=0)
  if not set(map(tuple, unique_colors)).issubset(
      six.viewkeys(rgb_to_semantic_label)):
    raise ValueError('RGB image has colors not in color map.')

  output_dtype = output_dtype or type(
      next(six.itervalues(rgb_to_semantic_label)))
  output_labels = np.empty(rgb_image.shape[:2], dtype=output_dtype)
  for rgb_color, int_label in six.iteritems(rgb_to_semantic_label):
    # Broadcast the color to (1, 1, 3) and match it against every pixel.
    color_array = np.array(rgb_color, ndmin=3)
    output_labels[np.all(rgb_image == color_array, axis=2)] = int_label
  return output_labels
Example #17
Source File: _interface.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def hook(self):
        '''Physically connect all of the hooks controlled by this class.

        Each target in ``self.__cache__`` is connected with the closure from
        ``self.apply(target)``. A failed connection is logged and the
        remaining targets are still attempted. Returns True only if every
        connection succeeded.
        '''
        all_ok = True

        # Iterating the mapping yields its keys on Python 2 and 3 alike;
        # dict.viewkeys() exists only on Python 2.
        for target in self.__cache__:
            if self.connect(target, self.apply(target)):
                continue
            logging.warning(u"{:s}.hook() : Error trying to connect to the specified {:s}.".format('.'.join(('internal', __name__, self.__class__.__name__)), self.__formatter__(target)))
            all_ok = False
        return all_ok
Example #18
Source File: _interface.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def unhook(self):
        '''Physically disconnect all of the hooks controlled by this class.

        Every target in ``self.__cache__`` is disconnected. A failure is
        logged and the remaining targets are still attempted. Returns True
        only if every disconnection succeeded.
        '''
        all_ok = True

        # Iterating the mapping yields its keys on Python 2 and 3 alike;
        # dict.viewkeys() exists only on Python 2.
        for target in self.__cache__:
            if self.disconnect(target):
                continue
            logging.warning(u"{:s}.unhook() : Error trying to disconnect from the specified {:s}.".format('.'.join(('internal', __name__, self.__class__.__name__)), self.__formatter__(target)))
            all_ok = False
        return all_ok
Example #19
Source File: _interface.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def disable(self, target):
        '''Disable execution of all the callables for the specified `target`.

        Returns True when `target` is newly disabled. Returns False, after
        logging, when `target` has no entry in ``self.__cache__`` or is
        already disabled.
        '''
        cls = self.__class__
        if target not in self.__cache__:
            # map() over the mapping itself yields its keys on Python 2 and 3;
            # dict.viewkeys() exists only on Python 2.
            logging.fatal(u"{:s}.disable({!r}) : The requested {:s} does not exist. Available hooks are: {:s}.".format('.'.join(('internal', __name__, cls.__name__)), target, self.__formatter__(target), "{{{:s}}}".format(', '.join(map("{!r}".format, self.__cache__)))))
            return False
        if target in self.__disabled:
            logging.warning(u"{:s}.disable({!r}) : {:s} has already been disabled. Currently disabled hooks are: {:s}.".format('.'.join(('internal', __name__, cls.__name__)), target, self.__formatter__(target).capitalize(), "{{{:s}}}".format(', '.join(map("{!r}".format, self.__disabled)))))
            return False
        self.__disabled.add(target)
        return True
Example #20
Source File: _interface.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def _replace(self, **fields):
        '''Assign the specified `fields` to the fields within the tuple.'''
        fc = fields.copy()
        result = self._make(map(fc.pop, self._fields, self))
        if fc:
            cls = self.__class__
            logging.warn(u"{:s}._replace({:s}) : Unable to assign unknown field names ({:s}) to tuple.".format('.'.join(('internal', __name__, cls.__name__)), internal.utils.string.kwargs(fields), '{' + ', '.join(map(internal.utils.string.repr, six.viewkeys(fc))) + '}'))
        return result 
Example #21
Source File: html5parser.py (project: bazarr, license: GNU General Public License v3.0, 5 votes)
def adjust_attributes(token, replacements):
    """Rename the attributes in ``token['data']`` using `replacements`.

    Nothing happens unless at least one attribute name is a key of
    `replacements`. In that case ``token['data']`` is replaced by an
    OrderedDict with the renamed keys, in the original attribute order.
    """
    attrs = token['data']
    if not viewkeys(attrs) & viewkeys(replacements):
        return
    token['data'] = OrderedDict(
        (replacements.get(name, name), value) for name, value in attrs.items())
Example #22
Source File: _comment.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def name(cls, address, **target):
        """Return the ``set`` of tag names stored for the contents of the function `target`.

        When `target` is not given or is ``None``, `address` is used to locate the function.
        """
        stored = cls._read(target.get('target', None), address) or {}
        return set(six.viewkeys(stored.get(cls.__tags__, {})))
Example #23
Source File: _comment.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def address(cls, address, **target):
        """Return a sorted list of the addresses with tags in the contents of the function `target`.

        When `target` is not given or is ``None``, `address` is used to locate the function.
        """
        stored = cls._read(target.get('target', None), address) or {}
        return sorted(six.viewkeys(stored.get(cls.__address__, {})))
Example #24
Source File: tags.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def globals(Globals, **tagmap):
        '''Apply the tags in `Globals` back into the database.

        `Globals` is iterated as ``(ea, tags)`` pairs. Each tag name is first
        renamed through `tagmap` (``old=new`` keyword pairs). Only tags whose
        value differs from what is already stored at `ea` are written.
        Renames that collide, and stored values that get overwritten, are
        logged. Returns the number of globals processed.
        '''
        global apply
        cls, tagmap_output = apply.__class__, u", {:s}".format(u', '.join(u"{:s}={:s}".format(internal.utils.string.escape(oldtag), internal.utils.string.escape(newtag)) for oldtag, newtag in six.iteritems(tagmap))) if tagmap else ''

        count = 0
        for ea, res in Globals:
            ns = func if func.within(ea) else db

            # grab the current (old) tag state
            state = ns.tag(ea)

            # transform the new tag state using the tagmap
            new = { tagmap.get(name, name) : value for name, value in six.viewitems(res) }

            # check if the tag mapping resulted in the deletion of a tag
            if len(new) != len(res):
                # Group original names by the name they map to. A plain rename
                # is not a collision, and its target need not exist in `res`,
                # so indexing res[target] blindly could raise KeyError.
                sources = {}
                for name in six.viewkeys(res):
                    sources.setdefault(tagmap.get(name, name), []).append(name)

                # Every name in res - new was renamed (unmapped names keep their
                # key), so tagmap[name] is always present here.
                for name in six.viewkeys(res) - six.viewkeys(new):
                    target = tagmap[name]
                    if len(sources[target]) < 2:
                        continue
                    # NOTE(review): despite "Refusing", the mapping is still
                    # applied below; `new` keeps the value that was seen last.
                    logging.warning(u"{:s}.globals(...{:s}) : Refusing requested tag mapping as it results in the tag \"{:s}\" overwriting the tag \"{:s}\" in the global {:#x}. The value {!s} would be replaced with {!s}.".format('.'.join((__name__, cls.__name__)), tagmap_output, internal.utils.string.escape(name, '"'), internal.utils.string.escape(target, '"'), ea, internal.utils.string.repr(res[name]), internal.utils.string.repr(res.get(target, new[target]))))
                pass

            # check what's going to be overwritten with different values prior to doing it
            for name in six.viewkeys(state) & six.viewkeys(new):
                if state[name] == new[name]: continue
                logging.warning(u"{:s}.globals(...{:s}) : Overwriting tag \"{:s}\" for global at {:#x} with new value {!s}. Old value was {!s}.".format('.'.join((__name__, cls.__name__)), tagmap_output, internal.utils.string.escape(name, '"'), ea, internal.utils.string.repr(new[name]), internal.utils.string.repr(state[name])))

            # now we can apply the tags to the global address (only the ones
            # whose value actually changes)
            try:
                for name, value in six.iteritems(new):
                    if state.get(name, dummy) != value:
                        ns.tag(ea, name, value)
            except Exception:
                logging.warning(u"{:s}.globals(...{:s}) : Unable to apply tags ({!s}) to global {:#x}.".format('.'.join((__name__, cls.__name__)), tagmap_output, internal.utils.string.repr(new), ea), exc_info=True)

            # increase our counter
            count += 1
        return count

    ## applying contents tags to all the functions 
Example #25
Source File: tags.py (project: ida-minsc, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def contents(Contents, **tagmap):
        '''Apply the tags in `Contents` back into each function within the database.

        `Contents` is iterated as ``(location, tags)`` pairs, and each location
        is converted with ``locationToAddress``. Tag names are renamed through
        `tagmap` (``old=new`` keyword pairs). Only tags whose value differs
        from what ``db.tag`` already reports are written. Renames that
        collide, and stored values that get overwritten, are logged. Returns
        the number of locations processed.
        '''
        global apply
        cls, tagmap_output = apply.__class__, u", {:s}".format(u', '.join(u"{:s}={:s}".format(internal.utils.string.escape(oldtag), internal.utils.string.escape(newtag)) for oldtag, newtag in six.iteritems(tagmap))) if tagmap else ''

        count = 0
        for loc, res in Contents:
            ea = locationToAddress(loc)

            # warn the user if this address is not within a function
            if not func.within(ea):
                logging.warning(u"{:s}.contents(...{:s}) : Address {:#x} is not within a function. Using a global tag.".format('.'.join((__name__, cls.__name__)), tagmap_output, ea))

            # grab the current (old) tag state
            state = db.tag(ea)

            # transform the new tag state using the tagmap
            new = { tagmap.get(name, name) : value for name, value in six.viewitems(res) }

            # check if the tag mapping resulted in the deletion of a tag
            if len(new) != len(res):
                # Group original names by the name they map to. A plain rename
                # is not a collision, and its target need not exist in `res`,
                # so indexing res[target] blindly could raise KeyError.
                sources = {}
                for name in six.viewkeys(res):
                    sources.setdefault(tagmap.get(name, name), []).append(name)

                # Every name in res - new was renamed (unmapped names keep their
                # key), so tagmap[name] is always present here.
                for name in six.viewkeys(res) - six.viewkeys(new):
                    target = tagmap[name]
                    if len(sources[target]) < 2:
                        continue
                    # NOTE(review): despite "Refusing", the mapping is still
                    # applied below; `new` keeps the value that was seen last.
                    logging.warning(u"{:s}.contents(...{:s}) : Refusing requested tag mapping as it results in the tag \"{:s}\" overwriting tag \"{:s}\" for the contents at {:#x}. The value {!s} would be overwritten by {!s}.".format('.'.join((__name__, cls.__name__)), tagmap_output, internal.utils.string.escape(name, '"'), internal.utils.string.escape(target, '"'), ea, internal.utils.string.repr(res[name]), internal.utils.string.repr(res.get(target, new[target]))))
                pass

            # inform the user if any tags are being overwritten with different values
            for name in six.viewkeys(state) & six.viewkeys(new):
                if state[name] == new[name]: continue
                logging.warning(u"{:s}.contents(...{:s}) : Overwriting contents tag \"{:s}\" for address {:#x} with new value {!s}. Old value was {!s}.".format('.'.join((__name__, cls.__name__)), tagmap_output, internal.utils.string.escape(name, '"'), ea, internal.utils.string.repr(new[name]), internal.utils.string.repr(state[name])))

            # write the tags to the contents address (only the ones whose
            # value actually changes)
            try:
                for name, value in six.iteritems(new):
                    if state.get(name, dummy) != value:
                        db.tag(ea, name, value)
            except Exception:
                logging.warning(u"{:s}.contents(...{:s}) : Unable to apply tags {!s} to location {:#x}.".format('.'.join((__name__, cls.__name__)), tagmap_output, internal.utils.string.repr(new), ea), exc_info=True)

            # increase our counter
            count += 1
        return count

    ## applying frames to all the functions 
Example #26
Source File: breakpoints_manager.py (project: cloud-debug-python, license: Apache License 2.0, 5 votes)
def SetActiveBreakpoints(self, breakpoints_data):
    """Syncs the active breakpoints with `breakpoints_data`.

    While holding ``self._lock``, this method:
      * clears and drops active breakpoints that are no longer listed;
      * creates breakpoints for listed ids that are neither active nor
        completed;
      * forgets completed ids that are no longer listed;
      * resets the next expiration time.

    Args:
      breakpoints_data: updated list of active breakpoints.
    """
    with self._lock:
      ids = {x['id'] for x in breakpoints_data}

      # Clear breakpoints that no longer show up in active breakpoints list.
      stale_ids = six.viewkeys(self._active) - ids
      for breakpoint_id in stale_ids:
        self._active.pop(breakpoint_id).Clear()

      # Create new breakpoints. The filter is re-evaluated per item because
      # constructing a breakpoint receives `self` and may update the
      # manager's state before the next item is checked.
      new_entries = [
          (x['id'],
           python_breakpoint.PythonBreakpoint(
               x, self._hub_client, self, self.data_visibility_policy))
          for x in breakpoints_data
          if x['id'] in ids - six.viewkeys(self._active) - self._completed]
      self._active.update(new_entries)

      # Completed ids absent from breakpoints_data are confirmed removed by the
      # hub, which never reuses breakpoint IDs, so they can be forgotten.
      self._completed &= ids

      # With active breakpoints the expiration is unknown; with none, nothing
      # can expire.
      self._next_expiration = datetime.min if self._active else datetime.max
Example #27
Source File: cache.py (project: odoo-rpc-client, license: Mozilla Public License 2.0, 5 votes)
def update_keys(self, keys):
        """ Add new IDs to cache.

            Each ID not already cached gets the entry ``{'id': cid}``; IDs
            already present keep their existing entries.

            :param list keys: list of new IDs to be added to cache
            :return: self
            :rtype: ObjectCache
        """
        if self:
            # Skip IDs that are already cached so their entries are preserved.
            fresh = set(keys).difference(six.viewkeys(self))
        else:
            # Empty cache: nothing to preserve, so skip building a set
            # (cheaper for large amounts of data).
            fresh = keys
        self.update({cid: {'id': cid} for cid in fresh})
        return self
Example #28
Source File: item.py (project: scrapy-jsonschema, license: BSD 3-Clause "New" or "Revised" License, 5 votes)
def _merge_schema(base, new):
    if base is None or new is None:
        return base or new

    if all(isinstance(x, dict) for x in (base, new)):
        return {
            key: _merge_schema(base.get(key), new.get(key))
            for key in six.viewkeys(base) | six.viewkeys(new)
        }
    if all(isinstance(x, (list, tuple)) for x in (base, new)):
        return list(base) + list(new)
    return base 
Example #29
Source File: exp4p.py (project: striatum, license: BSD 2-Clause "Simplified" License, 5 votes)
def _exp4p_score(self, context):
        """The main part of Exp4.P: turn advisor advice into action probabilities.

        `context` is indexed as ``context[advisor_id][action_id]``. For each
        action the probability is
        ``(1 - n_actions * p_min) * sum_i(w_i * advice_i) / sum(w) + p_min``,
        and the vector is then normalised to sum to 1. Advisor weights all
        start at 1 when the stored ``w`` is empty. The probabilities and
        weights are saved back to the model storage.

        Returns ``(estimated_reward, uncertainty, score)`` dicts keyed by
        action id. The reward and the score are both the action probability;
        the uncertainty is 0.
        """
        advisor_ids = list(six.viewkeys(context))

        w = self._modelstorage.get_model()['w']
        if len(w) == 0:
            # first round: every advisor starts with weight 1
            for advisor_id in advisor_ids:
                w[advisor_id] = 1
        w_sum = sum(six.viewvalues(w))

        exploit_share = 1 - self.n_actions * self.p_min
        probs = []
        for action_id in self.action_ids:
            weighted = [w[a] * context[a][action_id] for a in advisor_ids]
            probs.append(exploit_share * (np.sum(weighted) / w_sum)
                         + self.p_min)
        action_probs_list = np.asarray(probs)
        action_probs_list /= action_probs_list.sum()

        estimated_reward = dict(zip(self.action_ids, action_probs_list))
        uncertainty = dict.fromkeys(estimated_reward, 0)
        score = dict(estimated_reward)
        self._modelstorage.save_model(
            {'action_probs': estimated_reward, 'w': w})

        return estimated_reward, uncertainty, score
Example #30
Source File: exp4p.py (project: striatum, license: BSD 2-Clause "Simplified" License, 5 votes)
def reward(self, history_id, rewards):
        """Reward the previous action with reward.

        Using the context stored for `history_id` and the saved action
        probabilities, this multiplies each advisor's weight in ``w`` by an
        exponential factor for every rewarded action. It then saves the model
        and records `rewards` in the history storage.

        Parameters
        ----------
        history_id : int
            The history id of the action to reward.

        rewards : dictionary
            The dictionary {action_id, reward}, where reward is a float.
        """
        history = self._historystorage.get_unrewarded_history(history_id)
        context = history.context

        model = self._modelstorage.get_model()
        w = model['w']
        action_probs = model['action_probs']
        # Action ids are read from the first advisor's advice; this presumably
        # assumes every advisor covers the same actions.
        action_ids = list(six.viewkeys(six.next(six.itervalues(context))))

        # Update the model
        for action_id, reward in six.viewitems(rewards):
            for advisor in six.viewkeys(context):
                y_hat = (context[advisor][action_id] * reward
                         / action_probs[action_id])
                v_hat = sum(
                    [context[advisor][k] / action_probs[k] for k in action_ids])
                w[advisor] = w[advisor] * np.exp(
                    self.p_min / 2
                    * (y_hat + v_hat
                       * np.sqrt(np.log(len(context) / self.delta)
                                 / (len(action_ids) * self.max_rounds))))

        self._modelstorage.save_model({
            'action_probs': action_probs, 'w': w})

        # Update the history
        self._historystorage.add_reward(history_id, rewards)