Python torch.is_floating_point() Examples
The following are 14 code examples of torch.is_floating_point(), collected from open-source projects; the source file and project are listed above each example.
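Before the project examples, here is a minimal, self-contained sketch (not taken from any of the projects below) showing what torch.is_floating_point() reports for a few common dtypes.

import torch

# torch.is_floating_point(tensor) is True exactly for the floating dtypes
# (float16/half, bfloat16, float32, float64) and False for everything else.
for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64,
              torch.uint8, torch.int64, torch.bool):
    t = torch.zeros(2, dtype=dtype)
    print(dtype, torch.is_floating_point(t))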
Example #1
Source File: functions.py From seamseg with BSD 3-Clause "New" or "Revised" License
def forward(ctx, x, bbx, idx, roi_size, interpolation, padding, valid_mask):
    ctx.save_for_backward(bbx, idx)
    ctx.input_shape = (x.size(0), x.size(2), x.size(3))
    ctx.valid_mask = valid_mask

    try:
        ctx.interpolation = _INTERPOLATION[interpolation]
    except KeyError:
        raise ValueError("Unknown interpolation {}".format(interpolation))
    try:
        ctx.padding = _PADDING[padding]
    except KeyError:
        raise ValueError("Unknown padding {}".format(padding))

    y, mask = _backend.roi_sampling_forward(
        x, bbx, idx, roi_size, ctx.interpolation, ctx.padding, valid_mask)

    if not torch.is_floating_point(x):
        ctx.mark_non_differentiable(y)

    if valid_mask:
        ctx.mark_non_differentiable(mask)
        return y, mask
    else:
        return y
Example #2
Source File: misc.py From torchdiffeq with MIT License
def _check_inputs(func, y0, t):
    tensor_input = False
    if torch.is_tensor(y0):
        tensor_input = True
        y0 = (y0,)
        _base_nontuple_func_ = func
        func = lambda t, y: (_base_nontuple_func_(t, y[0]),)
    assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    for y0_ in y0:
        assert torch.is_tensor(y0_), 'each element must be a torch.Tensor but received {}'.format(type(y0_))

    if _decreasing(t):
        t = -t
        _base_reverse_func = func
        func = lambda t, y: tuple(-f_ for f_ in _base_reverse_func(-t, y))

    for y0_ in y0:
        if not torch.is_floating_point(y0_):
            raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(y0_.type()))
    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))

    return tensor_input, func, y0, t
Example #3
Source File: scatter.py From pytorch_scatter with MIT License
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count.clamp_(1)
    count = broadcast(count, out, dim)
    if torch.is_floating_point(out):
        out.true_divide_(count)
    else:
        out.floor_divide_(count)
    return out
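A small usage sketch of the function above, assuming the torch_scatter package is installed and exposes scatter_mean with this signature. It shows why the torch.is_floating_point() branch matters: floating-point input gets a true mean, integer input is floor-divided.

import torch
from torch_scatter import scatter_mean  # assumes torch_scatter is installed

index = torch.tensor([0, 0, 1, 1])

# Floating-point src: torch.is_floating_point(out) is True, so true division is used.
print(scatter_mean(torch.tensor([1.0, 2.0, 4.0, 8.0]), index))  # tensor([1.5000, 6.0000])

# Integer src: the floor-division branch is taken and the group means are rounded down.
print(scatter_mean(torch.tensor([1, 2, 4, 8]), index))          # tensor([1, 6])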
Example #4
Source File: softmax.py From pytorch_scatter with MIT License
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                    eps: float = 1e-12) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    index = broadcast(index, src, dim)

    max_value_per_index = scatter_max(src, index, dim=dim)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp()

    sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
    normalizing_constants = sum_per_index.add_(eps).gather(dim, index)

    return recentered_scores_exp.div(normalizing_constants)
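A hedged usage sketch (again assuming torch_scatter is installed and exports scatter_softmax): the torch.is_floating_point() guard makes integer input fail fast instead of silently producing a meaningless result.

import torch
from torch_scatter import scatter_softmax  # assumes torch_scatter is installed

src = torch.tensor([0.5, 1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 1, 1])

# Floating-point src: a softmax is computed independently within each index group.
print(scatter_softmax(src, index, dim=-1))

# Integer src trips the guard above and raises ValueError.
try:
    scatter_softmax(src.long(), index, dim=-1)
except ValueError as err:
    print(err)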
Example #5
Source File: softmax.py From pytorch_scatter with MIT License
def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                        eps: float = 1e-12) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    max_value_per_index = scatter_max(src, index, dim=dim)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element

    sum_per_index = scatter_sum(recentered_scores.exp(), index, dim)
    normalizing_constants = sum_per_index.add_(eps).log_().gather(dim, index)

    return recentered_scores.sub_(normalizing_constants)
Example #6
Source File: misc.py From occupancy_flow with MIT License
def _check_inputs(func, y0, t, f_options):
    tensor_input = False
    if torch.is_tensor(y0):
        tensor_input = True
        y0 = (y0,)
        _base_nontuple_func_ = func
        func = lambda t, y, **f_options: (_base_nontuple_func_(t, y[0], **f_options),)
    assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    for y0_ in y0:
        assert torch.is_tensor(y0_), 'each element must be a torch.Tensor but received {}'.format(type(y0_))

    if _decreasing(t):
        t = -t
        _base_reverse_func = func
        func = lambda t, y, **f_options: tuple(-f_ for f_ in _base_reverse_func(-t, y, **f_options))

    for y0_ in y0:
        if not torch.is_floating_point(y0_):
            raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(y0_.type()))
    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))

    return tensor_input, func, y0, t
Example #7
Source File: data_schemas.py From lale with Apache License 2.0
def torch_tensor_to_schema(tensor):
    assert torch_installed, """Your Python environment does not have torch installed. You can install it with pip install torch or with pip install 'lale[full]'"""
    assert isinstance(tensor, torch.Tensor)
    # https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype
    if tensor.dtype == torch.bool:
        result = {'type': 'boolean'}
    elif tensor.dtype == torch.uint8:
        result = {'type': 'integer', 'minimum': 0, 'maximum': 255}
    elif torch.is_floating_point(tensor):
        result = {'type': 'number'}
    else:
        result = {'type': 'integer'}
    for dim in reversed(tensor.shape):
        result = {
            'type': 'array',
            'minItems': dim,
            'maxItems': dim,
            'items': result}
    return result
Example #8
Source File: tensor.py From pytorch_sparse with MIT License
def is_floating_point(self) -> bool:
    value = self.storage.value()
    return torch.is_floating_point(value) if value is not None else True
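A usage sketch, under the assumption that this method lives on torch_sparse.SparseTensor (the pytorch_sparse project shown above): it defers to torch.is_floating_point() on the stored value tensor and defaults to True when no values are stored.

import torch
from torch_sparse import SparseTensor  # assumes torch_sparse is installed

row = torch.tensor([0, 1])
col = torch.tensor([1, 0])

with_values = SparseTensor(row=row, col=col, value=torch.tensor([1.0, 2.0]))
print(with_values.is_floating_point())  # True: the value tensor is float32

no_values = SparseTensor(row=row, col=col)
print(no_values.is_floating_point())    # True by convention when value is None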
Example #9
Source File: functions.py From seamseg with BSD 3-Clause "New" or "Revised" License
def backward(ctx, *args):
    if ctx.valid_mask:
        dy, _ = args
    else:
        dy = args[0]

    assert torch.is_floating_point(dy), "ROISampling.backward is only defined for floating point types"
    bbx, idx = ctx.saved_tensors

    dx = _backend.roi_sampling_backward(dy, bbx, idx, ctx.input_shape, ctx.interpolation, ctx.padding)
    return dx, None, None, None, None, None, None
Example #10
Source File: base_trainer.py From packnet-sfm with MIT License
def sample_to_cuda(data, dtype=None):
    if isinstance(data, str):
        return data
    elif isinstance(data, dict):
        return {key: sample_to_cuda(data[key], dtype) for key in data.keys()}
    elif isinstance(data, list):
        return [sample_to_cuda(val, dtype) for val in data]
    else:
        # only convert floats (e.g., to half), otherwise preserve (e.g., ints)
        dtype = dtype if torch.is_floating_point(data) else None
        return data.to('cuda', dtype=dtype)
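The function above only changes dtype for floating-point tensors, so integer fields (labels, indices) keep their type. Below is a CPU-only sketch of the same gating idea, using a hypothetical floats_to_half helper rather than packnet-sfm's actual API.

import torch

def floats_to_half(data):
    # Hypothetical helper: recurse through containers and cast only
    # floating-point tensors, leaving integer/bool tensors untouched.
    if isinstance(data, dict):
        return {key: floats_to_half(value) for key, value in data.items()}
    if isinstance(data, list):
        return [floats_to_half(value) for value in data]
    if torch.is_floating_point(data):
        return data.to(dtype=torch.half)
    return data

sample = {'image': torch.rand(1, 3, 4, 4), 'idx': torch.tensor([0, 1, 2])}
converted = floats_to_half(sample)
print(converted['image'].dtype, converted['idx'].dtype)  # torch.float16 torch.int64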
Example #11
Source File: logsumexp.py From pytorch_scatter with MIT License
def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    else:
        if dim_size is None:
            dim_size = int(index.max()) + 1

    size = list(src.size())
    size[dim] = dim_size
    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
                                     device=src.device)
    scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_score = src - max_per_src_element
    recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))

    if out is not None:
        out = out.sub(max_per_src_element).exp()

    sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out, dim_size)

    return sum_per_index.add_(eps).log_().add_(max_value_per_index)
Example #12
Source File: compat.py From apex with BSD 3-Clause "New" or "Revised" License
def is_tensor_like(x):
    return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)


# Wraps `torch.is_floating_point` if present, otherwise checks
# the suffix of `x.type()`.
Example #13
Source File: compat.py From apex with BSD 3-Clause "New" or "Revised" License
def is_floating_point(x):
    if hasattr(torch, 'is_floating_point'):
        return torch.is_floating_point(x)
    try:
        torch_type = x.type()
        return torch_type.endswith('FloatTensor') or \
            torch_type.endswith('HalfTensor') or \
            torch_type.endswith('DoubleTensor')
    except AttributeError:
        return False
Example #14
Source File: technology.py From torchfunc with MIT License
def _analyse(self, module, function: str):
    def _correct_types(data, submodule, index, is_float: bool):
        correct_types = self.float_types if is_float else self.integer_types
        if not any(
            correct_type == submodule.weight.dtype for correct_type in correct_types
        ):
            data["type"]["float" if is_float else "integer"].append(index)

    def _correct_shapes(
        data, submodule, index, attributes, attribute_name, is_float: bool
    ):
        for attribute in attributes[type(submodule)]:
            if hasattr(submodule, attribute):
                shape = getattr(submodule, attribute)
                correct = shape % (8 if is_float else 16) == 0
                if not correct:
                    data["shape"]["float" if is_float else "integer"][
                        attribute_name
                    ].append(index)

    def _find_problems(data, submodule, index, is_float: bool):
        def _operation_problems(operation: str):
            for entry in ("inputs", "outputs"):
                _correct_shapes(
                    data,
                    submodule,
                    index,
                    getattr(self, operation + "_" + entry),
                    entry,
                    is_float,
                )

        _correct_types(data, submodule, index, is_float)
        if isinstance(submodule, self.linear_types):
            _operation_problems("linear")
        elif isinstance(submodule, self.convolution_types):
            _operation_problems("convolution")

    ###########################################################################
    #
    #                               MAIN FUNCTION
    #
    ###########################################################################

    data = {
        "type": {"float": [], "integer": []},
        "shape": {
            "float": {"inputs": [], "outputs": []},
            "integer": {"inputs": [], "outputs": []},
        },
    }

    for index, submodule in enumerate(getattr(module, function)()):
        if hasattr(submodule, "weight"):
            if torch.is_floating_point(submodule.weight):
                _find_problems(data, submodule, index, is_float=True)
            else:
                _find_problems(data, submodule, index, is_float=False)

    return data