Python six.ensure_binary() Examples

The following are 30 code examples of six.ensure_binary(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module six, or try the search function.
Example #1
Source File: metadata.py    From vessel-classification with Apache License 2.0 6 votes vote down vote up
def __init__(self, metadata_dict, fishing_ranges_map):
        """Index vessel metadata by binary id and log how much data overlaps.

        Args:
            metadata_dict: mapping of split -> {vessel id -> metadata}.
            fishing_ranges_map: mapping keyed by vessel id.
        """
        self.metadata_by_split = metadata_dict
        self.metadata_by_id = {}
        self.fishing_ranges_map = fishing_ranges_map
        self.id_map_int2bytes = {}
        for _split, vessels in metadata_dict.items():
            for raw_id, data in vessels.items():
                # Ids are normalised to bytes before being used as keys.
                binary_id = six.ensure_binary(raw_id)
                self.metadata_by_id[binary_id] = data
                self.id_map_int2bytes[stable_hash(binary_id)] = binary_id

        shared_ids = set(self.metadata_by_id) & set(fishing_ranges_map.keys())
        logging.info("Metadata for %d ids.", len(self.metadata_by_id))
        logging.info("Fishing ranges for %d ids.", len(fishing_ranges_map))
        logging.info("Vessels with both types of data: %d", len(shared_ids))
Example #2
Source File: utils.py    From git-pw with MIT License 6 votes vote down vote up
def _echo_via_pager(pager, output):
    """Pipe ``output`` (as bytes) into the ``pager`` command and wait for it.

    ``pager`` is a command line string that is split on whitespace.
    """
    env = dict(os.environ)
    # Mirror Git: default LESS to FRX only when the user has not set it.
    env.setdefault('LESS', 'FRX')

    proc = subprocess.Popen(pager.split(), stdin=subprocess.PIPE, env=env)

    payload = six.ensure_binary(output)

    try:
        proc.communicate(input=payload)
    except (IOError, KeyboardInterrupt):
        pass
    else:
        proc.stdin.close()

    # Keep waiting until the pager exits, ignoring Ctrl-C in the meantime.
    while True:
        try:
            proc.wait()
        except KeyboardInterrupt:
            continue
        break
Example #3
Source File: epubmerge.py    From EpubMerge with GNU General Public License v3.0 6 votes vote down vote up
def ensure_binary(s, encoding='utf-8', errors='strict'):
        """Return **s** as six.binary_type, encoding it when it is text.

        For Python 2:
          - `unicode` -> encoded to `str`
          - `str` -> `str`

        For Python 3:
          - `str` -> encoded to `bytes`
          - `bytes` -> `bytes`

        Raises TypeError for any other input type.
        """
        if isinstance(s, text_type):
            return s.encode(encoding, errors)
        if isinstance(s, binary_type):
            return s
        raise TypeError("not expecting type '%s'" % type(s))
Example #4
Source File: tokenization.py    From albert with Apache License 2.0 6 votes vote down vote up
def printable_text(text):
  """Returns `text` as the native `str` type, for print or `tf.logging`."""

  # Native `str` is a Unicode string on Python 3 and a byte string on
  # Python 2, so the conversion direction depends on the interpreter.
  if six.PY3:
    if isinstance(text, str):
      return text
    if isinstance(text, bytes):
      return six.ensure_text(text, "utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))

  if six.PY2:
    if isinstance(text, str):
      return text
    if isinstance(text, six.text_type):
      return six.ensure_binary(text, "utf-8")
    raise ValueError("Unsupported string type: %s" % (type(text)))

  raise ValueError("Not running on Python2 or Python 3?")
Example #5
Source File: tokenization_utils.py    From Senta with Apache License 2.0 6 votes vote down vote up
def printable_text(text):
  """Returns `text` as the native `str` type, for print or `tf.logging`."""

  # Native `str` is a Unicode string on Python 3 and a byte string on
  # Python 2, so the conversion direction depends on the interpreter.
  if six.PY3:
    if isinstance(text, str):
      return text
    if isinstance(text, bytes):
      return six.ensure_text(text, "utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))

  if six.PY2:
    if isinstance(text, str):
      return text
    if isinstance(text, six.text_type):
      return six.ensure_binary(text, "utf-8")
    raise ValueError("Unsupported string type: %s" % (type(text)))

  raise ValueError("Not running on Python2 or Python 3?")
Example #6
Source File: tokenization_test.py    From albert with Apache License 2.0 6 votes vote down vote up
def test_full_tokenizer(self):
    """Checks FullTokenizer output against a small vocab written to disk."""
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", ","
    ]
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      # One token per line; on Python 3 the file needs bytes.
      if six.PY2:
        payload = "".join([x + "\n" for x in vocab_tokens])
      else:
        joined = "".join([six.ensure_str(x) + "\n" for x in vocab_tokens])
        payload = six.ensure_binary(joined, "utf-8")
      vocab_writer.write(payload)

      vocab_file = vocab_writer.name

    tokenizer = tokenization.FullTokenizer(vocab_file)
    os.unlink(vocab_file)

    expected_tokens = ["un", "##want", "##ed", ",", "runn", "##ing"]
    expected_ids = [7, 4, 5, 10, 8, 9]

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertAllEqual(tokens, expected_tokens)

    self.assertAllEqual(tokenizer.convert_tokens_to_ids(tokens), expected_ids)
Example #7
Source File: albert_tokenization.py    From bert-for-tf2 with MIT License 6 votes vote down vote up
def printable_text(text):
    """Returns `text` as the native `str` type, for print or `tf.logging`."""

    # Native `str` is a Unicode string on Python 3 and a byte string on
    # Python 2, so the conversion direction depends on the interpreter.
    if six.PY3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return six.ensure_text(text, "utf-8", "ignore")
        raise ValueError("Unsupported string type: %s" % (type(text)))

    if six.PY2:
        if isinstance(text, str):
            return text
        if isinstance(text, six.text_type):
            return six.ensure_binary(text, "utf-8")
        raise ValueError("Unsupported string type: %s" % (type(text)))

    raise ValueError("Not running on Python2 or Python 3?")
Example #8
Source File: checkpointer.py    From lingvo with Apache License 2.0 6 votes vote down vote up
def RestoreGlobalStepIfNeeded(self, sess):
    """Restores or initializes the global step when it is uninitialized.

    The value is read from the latest checkpoint in the train dir if one
    exists; otherwise the variable's initializer is run.

    Args:
      sess: tf.Session.
    """
    assert not self._save_only
    uninitialized = self._GetUninitializedVarNames(sess)
    if six.ensure_binary('global_step') not in uninitialized:
      return

    with sess.graph.as_default():
      gstep = py_utils.GetGlobalStep()

    ckpt_path = tf.train.latest_checkpoint(self._train_dir)
    if not ckpt_path:
      tf.logging.info('Initializing global step')
      sess.run(gstep.initializer)
      return

    step_value = tf.train.NewCheckpointReader(ckpt_path).get_tensor(
        'global_step')
    tf.logging.info('Restoring global step: %s', step_value)
    sess.run(gstep.assign(step_value))
Example #9
Source File: utils.py    From nmt-wizard-docker with MIT License 6 votes vote down vote up
def md5files(files):
    """Computes the combined MD5 hash of multiple files, represented as a list
    of (key, path).

    Entries are processed in key order. For each entry the key is hashed
    first; a directory then contributes the hex digest of its non-hidden
    children (computed recursively), while a regular file contributes its
    raw contents.

    Returns the hexadecimal digest string.
    """
    m = hashlib.md5()
    for key, path in sorted(files, key=lambda x: x[0]):
        m.update(six.ensure_binary(key))
        if os.path.isdir(path):
            sub_md5 = md5files([
                (os.path.join(key, filename), os.path.join(path, filename))
                for filename in os.listdir(path)
                if not filename.startswith('.')])
            m.update(six.ensure_binary(sub_md5))
        else:
            with open(path, 'rb') as f:
                # Feed the file in fixed-size chunks so large files are not
                # loaded into memory at once; the digest is unchanged.
                for chunk in iter(lambda: f.read(1 << 16), b''):
                    m.update(chunk)
    return m.hexdigest()
Example #10
Source File: tokenization.py    From embedding-as-service with MIT License 6 votes vote down vote up
def printable_text(text):
    """Returns `text` as the native `str` type, for print or `tf.logging`."""

    # Native `str` is a Unicode string on Python 3 and a byte string on
    # Python 2, so the conversion direction depends on the interpreter.
    if six.PY3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return six.ensure_text(text, "utf-8", "ignore")
        raise ValueError("Unsupported string type: %s" % (type(text)))

    if six.PY2:
        if isinstance(text, str):
            return text
        if isinstance(text, six.text_type):
            return six.ensure_binary(text, "utf-8")
        raise ValueError("Unsupported string type: %s" % (type(text)))

    raise ValueError("Not running on Python2 or Python 3?")
Example #11
Source File: util.py    From scalyr-agent-2 with Apache License 2.0 6 votes vote down vote up
def json_encode(obj, output=None, binary=False):
    """Serializes an object to JSON, as text or as bytes.

    In binary mode with an ``output`` given, the bytes are written to
    ``output`` and nothing is returned.

    @param obj: The object to serialize
    @param output: If not None, a file-like object to which the serialization should be written.
    @param binary: If True return binary string, otherwise text string.
    @type obj: dict|list|six.text_type
    @type binary: bool
    """
    # 2->TODO encode json according to 'binary' flag.
    if not binary:
        return six.ensure_text(_json_encode(obj, output))

    encoded = six.ensure_binary(_json_encode(obj, None))
    if not output:
        return encoded
    output.write(encoded)
Example #12
Source File: test_compute_log_manager.py    From dagster with Apache License 2.0 6 votes vote down vote up
def test_compute_log_manager_from_config(s3_bucket):
    """Loads an instance from a dagster.yaml and checks the S3 log settings."""
    s3_prefix = 'foobar'

    dagster_yaml = '''
compute_logs:
  module: dagster_aws.s3.compute_log_manager
  class: S3ComputeLogManager
  config:
    bucket: "{s3_bucket}"
    local_dir: "/tmp/cool"
    prefix: "{s3_prefix}"
'''.format(s3_bucket=s3_bucket, s3_prefix=s3_prefix)

    with seven.TemporaryDirectory() as tempdir:
        config_path = os.path.join(tempdir, 'dagster.yaml')
        with open(config_path, 'wb') as config_file:
            config_file.write(six.ensure_binary(dagster_yaml))

        instance = DagsterInstance.from_config(tempdir)

    # pylint: disable=protected-access
    assert instance.compute_log_manager._s3_bucket == s3_bucket
    assert instance.compute_log_manager._s3_prefix == s3_prefix
Example #13
Source File: utils.py    From allura with Apache License 2.0 6 votes vote down vote up
def enc(self, plain, css_safe=False):
        '''Obfuscate a field name by XOR-ing it with the spinner.

        Not production-grade encryption, but hopefully "good enough" to
        stop spammers.  With ``css_safe`` only alphabetic characters of
        the result are kept.
        '''
        # Field names must be plain ascii; this raises for anything else.
        # The logic below is not expected to work with multi-byte text.
        plain.encode('ascii')

        # Layout: length prefix, character ordinals, then random padding.
        ordinals = ([len(plain)]
                    + [ord(ch) for ch in plain]
                    + self.random_padding[:len(self.spinner_ord) - len(plain) - 1])
        xored = ''.join(
            six.unichr(p ^ s) for p, s in zip(ordinals, self.spinner_ord))
        wrapped = self._wrap(six.ensure_binary(xored))
        result = six.ensure_text(wrapped)
        if css_safe:
            result = ''.join(ch for ch in result if ch.isalpha())
        return result
Example #14
Source File: utils.py    From allura with Apache License 2.0 6 votes vote down vote up
def make_spinner(self, timestamp=None):
        """Return the SHA-1 digest of timestamp, IP prefix and spinner secret.

        ``timestamp`` defaults to ``self.timestamp``.  ``self.client_ip`` is
        set as a side effect, falling back to '127.0.0.1'.
        """
        if timestamp is None:
            timestamp = self.timestamp
        try:
            client_ip = ip_address(self.request)
        except (TypeError, AttributeError):
            client_ip = '127.0.0.1'

        # Tests sometimes have no remote_addr on the request, so an empty
        # address also falls back to localhost.
        self.client_ip = client_ip or '127.0.0.1'

        # Only the first three dot-separated parts of the address are used.
        ip_chunk = '.'.join(self.client_ip.split('.')[0:3])
        secret = tg.config.get('spinner_secret', 'abcdef')
        plain = '%d:%s:%s' % (timestamp, ip_chunk, secret)
        return hashlib.sha1(six.ensure_binary(plain)).digest()
Example #15
Source File: multifactor.py    From allura with Apache License 2.0 6 votes vote down vote up
def verify(self, totp, code, user):
        """Check ``code`` against ``totp`` over the configured time windows.

        Returns the result of ``totp.verify`` for the first window that
        accepts the code; re-raises InvalidToken if the last window fails.
        """
        # Google authenticator puts a space in their codes; the verifier
        # needs bytes, not text.
        code = six.ensure_binary(code.replace(' ', ''))

        self.enforce_rate_limit(user)

        # TODO prohibit re-use of a successful code, although it seems unlikely with a 30s window
        # see https://tools.ietf.org/html/rfc6238#section-5.2 paragraph 5

        # try the 1 previous time-window and current
        # per https://tools.ietf.org/html/rfc6238#section-5.2 paragraph 1
        windows = asint(config.get('auth.multifactor.totp.windows', 2))
        final_window = windows - 1
        for time_window in range(windows):
            try:
                return totp.verify(code, time() - time_window*30)
            except InvalidToken:
                if time_window == final_window:
                    raise
Example #16
Source File: dsrf_report_manager_test.py    From dsrf with Apache License 2.0 6 votes vote down vote up
def test_parse_report_valid_not_human_readable(self):
    """Parses a report in raw mode and checks the first queued block."""
    dsrf_xsd_file = path.join(
        path.dirname(__file__), '../testdata/sales-reporting-flat.xsd')
    avs_xsd_file = path.join(
        path.dirname(__file__), '../testdata/avs.xsd')
    files_list = [path.join(
        path.dirname(__file__), '../testdata/DSR_PADPIDA2014999999Z_'
        'PADPIDA2014111801Y_AdSupport_2015-02_AU_1of1_20150723T092522.tsv')]
    self.report_manager.parse_report(
        files_list, dsrf_xsd_file, avs_xsd_file,
        human_readable=False, write_head=False)
    # Read the queue file inside a context manager so the handle is closed
    # deterministically instead of being left to the garbage collector.
    with open('/tmp/queue.txt', 'r') as queue_file:
      queue_contents = queue_file.read()
    # Only the first block, up to the first queue delimiter, is checked.
    serialized_block_str = queue_contents.split(
        six.ensure_str('\n' + six.ensure_str(constants.QUEUE_DELIMITER)))[0]
    deserialized_block_str = six.ensure_binary(
        six.text_type(block_pb2.Block.FromString(serialized_block_str)),
        'utf-8')
    self.assertMultiLineEqual(BODY_BLOCK, deserialized_block_str)
Example #17
Source File: dsrf_report_manager.py    From dsrf with Apache License 2.0 6 votes vote down vote up
def write_to_queue(self, block, logger, human_readable=False):
    """Writes the block object to the output queue.

    The block is written to stdout followed by a newline-wrapped queue
    delimiter. Override this if you want to change the queue form.

    Args:
      block: A block_pb2.Block object to write.
      logger: Logger object.
      human_readable: If True, write to the queue the block in a human readable
                      form. Otherwise, Write the block as a raw bytes.
    """
    if human_readable:
      output = six.ensure_binary(six.text_type(block), 'utf8')
    else:
      output = block.SerializeToString()
    # bytes(<str>) without an encoding raises TypeError on Python 3, so the
    # delimiter is converted through six to work on both Python 2 and 3.
    delimiter = six.ensure_binary(
        '\n' + six.ensure_str(constants.QUEUE_DELIMITER) + '\n')
    try:
      os.write(sys.stdout.fileno(), output)
      os.write(sys.stdout.fileno(), delimiter)
    except OSError as e:
      logger.exception('Could not write to queue: %s', e)
      sys.stderr.write(
          'WARNING: Parser interrupted. Some blocks were not parsed.\n')
      sys.exit(-1)
Example #18
Source File: mfg_inspector.py    From openhtf with Apache License 2.0 6 votes vote down vote up
def __init__(self, user=None, keydata=None,
               token_uri=TOKEN_URI, destination_url=DESTINATION_URL):
    """Stores the upload settings and builds credentials when possible.

    Credentials are only created when both `user` and `keydata` are given;
    otherwise `self.credentials` is None.
    """
    self.user = user
    self.keydata = keydata
    self.token_uri = token_uri
    self.destination_url = destination_url

    if not (user and keydata):
      self.credentials = None
    else:
      # The private key is passed to oauth2client as bytes.
      self.credentials = oauth2client.client.SignedJwtAssertionCredentials(
          service_account_name=self.user,
          private_key=six.ensure_binary(self.keydata),
          scope=self.SCOPE_CODE_URI,
          user_agent='OpenHTF Guzzle Upload Client',
          token_uri=self.token_uri)
      self.credentials.set_store(_MemStorage())

    self.upload_result = None

    self._cached_proto = None
Example #19
Source File: metadata.py    From vessel-classification with Apache License 2.0 6 votes vote down vote up
def read_fishing_ranges(fishing_range_file):
    """ Parse a comma-separated fishing range file into a plain dict.

        The first line is skipped as a header. Each remaining line yields
        a FishingRange (UTC start, UTC end, float fishing score) that is
        appended to the list stored under the binary vessel id.
    """
    ranges_by_id = defaultdict(list)
    with open(fishing_range_file, 'r') as f:
        rows = f.readlines()[1:]
        for row in rows:
            fields = row.split(',')
            vessel_id = six.ensure_binary(fields[0].strip())
            start_time = parse_date(fields[1]).replace(tzinfo=pytz.utc)
            end_time = parse_date(fields[2]).replace(tzinfo=pytz.utc)
            is_fishing = float(fields[3])
            ranges_by_id[vessel_id].append(
                FishingRange(start_time, end_time, is_fishing))

    return dict(ranges_by_id)
Example #20
Source File: tokenization.py    From ALBERT-TF2.0 with Apache License 2.0 6 votes vote down vote up
def printable_text(text):
  """Returns `text` as the native `str` type, for print or `tf.logging`."""

  # Native `str` is a Unicode string on Python 3 and a byte string on
  # Python 2, so the conversion direction depends on the interpreter.
  if six.PY3:
    if isinstance(text, str):
      return text
    if isinstance(text, bytes):
      return six.ensure_text(text, "utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))

  if six.PY2:
    if isinstance(text, str):
      return text
    if isinstance(text, six.text_type):
      return six.ensure_binary(text, "utf-8")
    raise ValueError("Unsupported string type: %s" % (type(text)))

  raise ValueError("Not running on Python2 or Python 3?")
Example #21
Source File: oid_challenge_evaluation_utils_test.py    From models with Apache License 2.0 6 votes vote down vote up
def encode_mask(mask_to_encode):
  """Encodes a binary mask into the Kaggle challenge format.

  Three stages are applied in order:
   - COCO RLE-encoding,
   - zlib compression at the best compression level,
   - base64 encoding (to use as entry in csv file).

  Args:
    mask_to_encode: binary np.ndarray of dtype bool and 2d shape.

  Returns:
    The encoded mask as returned by `base64.b64encode`.
  """
  squeezed = np.squeeze(mask_to_encode)
  height, width = squeezed.shape[0], squeezed.shape[1]
  # The mask is passed to the COCO encoder as (H, W, 1) uint8 in Fortran
  # memory order.
  fortran_mask = np.asfortranarray(
      squeezed.reshape(height, width, 1).astype(np.uint8))
  rle_counts = coco_mask.encode(fortran_mask)[0]['counts']
  compressed = zlib.compress(
      six.ensure_binary(rle_counts), zlib.Z_BEST_COMPRESSION)
  return base64.b64encode(compressed)
Example #22
Source File: add_context_to_examples_tf1_test.py    From models with Apache License 2.0 5 votes vote down vote up
def _create_second_tf_example(self):
    """Builds the serialized tf.train.Example for image 'image_id_2'."""
    with self.test_session():
      ones_image = tf.constant(np.ones((4, 4, 3)).astype(np.uint8))
      encoded_image = tf.image.encode_jpeg(ones_image).eval()

    feature = {
        'image/encoded': BytesFeature(encoded_image),
        'image/source_id': BytesFeature(six.ensure_binary('image_id_2')),
        'image/height': Int64Feature(4),
        'image/width': Int64Feature(4),
        'image/object/class/label': Int64ListFeature([5]),
        'image/object/class/text': BytesListFeature(
            [six.ensure_binary('hyena')]),
        'image/object/bbox/xmin': FloatListFeature([0.0]),
        'image/object/bbox/xmax': FloatListFeature([0.1]),
        'image/object/bbox/ymin': FloatListFeature([0.2]),
        'image/object/bbox/ymax': FloatListFeature([0.3]),
        'image/seq_id': BytesFeature(six.ensure_binary('01')),
        'image/seq_num_frames': Int64Feature(2),
        'image/seq_frame_num': Int64Feature(1),
        'image/date_captured': BytesFeature(
            six.ensure_binary(str(datetime.datetime(2020, 1, 1, 1, 1, 0)))),
        'image/embedding': FloatListFeature([0.4, 0.5, 0.6]),
        'image/embedding_score': FloatListFeature([0.9]),
        'image/embedding_length': Int64Feature(3)
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))

    return example.SerializeToString()
Example #23
Source File: task.py    From allura with Apache License 2.0 5 votes vote down vote up
def __call__(self, environ, context):
        """Run the task from ``environ`` and return it as a plain-text response."""
        # see TGController / CoreDispatcher for reference on how this works on a normal controllers

        run_task = environ['task']
        task_output = run_task(
            restore_context=False, nocapture=environ['nocapture'])
        response = context.response
        # `None` default is problematic for some middleware
        response.headers['Content-Type'] = str('text/plain')
        # A falsy task result is sent as an empty body.
        response.body = six.ensure_binary(task_output or b'')
        return response
Example #24
Source File: metadata_test.py    From vessel-classification with Apache License 2.0 5 votes vote down vote up
def test_fixed_time_reader(self):
        """Check vessel weights and splits from the time-weighted reader."""
        parsed_lines = csv.DictReader(self.raw_lines)
        available_vessels = {
            six.ensure_binary(str(x)) for x in range(100001, 100014)}
        result = metadata.read_vessel_time_weighted_metadata_lines(
            available_vessels, parsed_lines, self.fishing_range_dict,
            'Test')

        expected_weights = (
            (b'100001', 1.0),
            (b'100002', 1.0),
            (b'100009', 3.0),
            (b'100012', 0.0),
        )
        for vessel_id, weight in expected_weights:
            self.assertEqual(weight, result.vessel_weight(vessel_id))

        self._check_splits(result)
Example #25
Source File: tokenization_spm.py    From Senta with Apache License 2.0 5 votes vote down vote up
def encode_pieces(sp_model, text, return_unicode=True, sample=False):
  """Turn a sentence into word pieces using a SentencePiece model.

  Args:
    sp_model: object providing `EncodeAsPieces` and `SampleEncodeAsPieces`.
    text: the input text; it is passed through `clean_text` first.
    return_unicode: on Python 2 only, convert `str` pieces back to unicode
      before returning.
    sample: if true, use `SampleEncodeAsPieces(text, 64, 0.1)` instead of
      `EncodeAsPieces(text)`.

  Returns:
    A list of pieces. A piece ending in a digit followed by "," is re-encoded
    without the comma and the comma is appended as its own piece.
  """

  # liujiaxiang: add for ernie-albert, mainly consider for “/”/‘/’/— causing too many unk
  text = clean_text(text)

  # On Python 2 the model is fed UTF-8 bytes rather than unicode.
  if six.PY2 and isinstance(text, six.text_type):
    text = six.ensure_binary(text, "utf-8")

  if not sample:
    pieces = sp_model.EncodeAsPieces(text)
  else:
    pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)

  new_pieces = []
  for piece in pieces:
    piece = printable_text(piece)
    # Special case: a piece of the form "<...digit>," is split so the
    # trailing comma becomes a separate piece.
    if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
      # Re-encode the piece without its comma and without the underline marker.
      cur_pieces = sp_model.EncodeAsPieces(
          six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b""))
      # If re-encoding introduced a leading underline marker that the
      # original piece did not have, strip it again.
      if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
        if len(cur_pieces[0]) == 1:
          cur_pieces = cur_pieces[1:]
        else:
          cur_pieces[0] = cur_pieces[0][1:]
      cur_pieces.append(piece[-1])
      new_pieces.extend(cur_pieces)
    else:
      new_pieces.append(piece)

  # note(zhiliny): convert back to unicode for py2
  if six.PY2 and return_unicode:
    ret_pieces = []
    for piece in new_pieces:
      if isinstance(piece, str):
        piece = six.ensure_text(piece, "utf-8")
      ret_pieces.append(piece)
    new_pieces = ret_pieces

  return new_pieces
Example #26
Source File: tokenization.py    From models with Apache License 2.0 5 votes vote down vote up
def encode_pieces(sp_model, text, sample=False):
  """Segments text into pieces.

  This method is used together with sentence piece tokenizer and is forked from:
  https://github.com/google-research/google-research/blob/master/albert/tokenization.py

  A piece ending in a digit followed by "," is re-encoded without the comma,
  and the comma is emitted as its own piece.

  Args:
    sp_model: A spm.SentencePieceProcessor object.
    text: The input text to be segmented.
    sample: Whether to randomly sample a segmentation output or return a
      deterministic one.

  Returns:
    A list of token pieces.
  """
  if six.PY2 and isinstance(text, six.text_type):
    text = six.ensure_binary(text, "utf-8")

  if sample:
    raw_pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
  else:
    raw_pieces = sp_model.EncodeAsPieces(text)

  result = []
  for raw_piece in raw_pieces:
    piece = printable_text(raw_piece)
    ends_with_digit_comma = (
        len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit())
    if not ends_with_digit_comma:
      result.append(piece)
      continue

    # Re-encode the piece without its comma and without underline markers.
    sub_pieces = sp_model.EncodeAsPieces(
        piece[:-1].replace(SPIECE_UNDERLINE, ""))
    # Drop a leading underline marker that re-encoding introduced.
    if piece[0] != SPIECE_UNDERLINE and sub_pieces[0][0] == SPIECE_UNDERLINE:
      if len(sub_pieces[0]) == 1:
        sub_pieces = sub_pieces[1:]
      else:
        sub_pieces[0] = sub_pieces[0][1:]
    sub_pieces.append(piece[-1])
    result.extend(sub_pieces)

  return result
Example #27
Source File: static.py    From allura with Apache License 2.0 5 votes vote down vote up
def tool_icon_css(self, *args, **kw):
        """
        Serve the tool icon stylesheet taken from ``g.tool_icon_css``.

        The stylesheet's md5 is used as the etag.  To use this, include it
        in your theme like:
            g.register_css('/nf/tool_icon_css?' + g.build_key, compress=False)

        """
        css, md5 = g.tool_icon_css
        stylesheet = BytesIO(six.ensure_binary(css))
        return utils.serve_file(
            stylesheet, 'tool_icon_css', 'text/css', etag=md5)
Example #28
Source File: tf_record_creation_util_test.py    From models with Apache License 2.0 5 votes vote down vote up
def test_sharded_tfrecord_writes(self):
    """Writes one record to each of 10 shards and reads them back."""
    num_shards = 10
    with contextlib2.ExitStack() as tf_record_close_stack:
      output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
          tf_record_close_stack,
          os.path.join(tf.test.get_temp_dir(), 'test.tfrec'), num_shards)
      for idx in range(num_shards):
        record = six.ensure_binary('test_{}'.format(idx))
        output_tfrecords[idx].write(record)

    for idx in range(num_shards):
      base_path = os.path.join(tf.test.get_temp_dir(), 'test.tfrec')
      tf_record_path = '{}-{:05d}-of-00010'.format(base_path, idx)
      records = list(tf.python_io.tf_record_iterator(tf_record_path))
      expected = ['test_{}'.format(idx).encode('utf-8')]
      self.assertAllEqual(records, expected)
Example #29
Source File: add_context_to_examples_tf1_test.py    From models with Apache License 2.0 5 votes vote down vote up
def test_beam_pipeline_sequence_example(self):
    """Runs the pipeline on two examples and checks the sequence output."""
    input_examples = [self._create_first_tf_example(),
                      self._create_second_tf_example()]
    with InMemoryTFRecord(input_examples) as input_tfrecord:
      temp_dir = tempfile.mkdtemp(dir=os.environ.get('TEST_TMPDIR'))
      output_tfrecord = os.path.join(temp_dir, 'output_tfrecord')
      sequence_key = six.ensure_binary('image/seq_id')
      max_num_elements = 10
      num_shards = 1
      pipeline_options = beam.options.pipeline_options.PipelineOptions(
          runner='DirectRunner')
      p = beam.Pipeline(options=pipeline_options)
      add_context_to_examples.construct_pipeline(
          p,
          input_tfrecord,
          output_tfrecord,
          sequence_key,
          max_num_elements_in_context_features=max_num_elements,
          num_shards=num_shards,
          output_type='tf_sequence_example')
      p.run()

      # Only the first output shard is read back.
      filenames = tf.io.gfile.glob(output_tfrecord + '-?????-of-?????')
      actual_output = list(
          tf.python_io.tf_record_iterator(path=filenames[0]))
      self.assertEqual(len(actual_output), 1)
      parsed_examples = [
          tf.train.SequenceExample.FromString(tf_example)
          for tf_example in actual_output
      ]
      self.assert_expected_sequence_example(parsed_examples)
Example #30
Source File: test_record.py    From openhtf with Apache License 2.0 5 votes vote down vote up
def __init__(self, data, mimetype):
    """Stores `data` as bytes in a temp file, with its SHA-1 and mimetype."""
    payload = six.ensure_binary(data)
    self.mimetype = mimetype
    self.sha1 = hashlib.sha1(payload).hexdigest()
    self._file = self._create_temp_file(payload)