torch.clamp()
函数用于对输入张量进行截断操作,将张量中的每个元素限制在指定的范围内。
其语法为:
torch.clamp(input, min, max, out=None) -> Tensor
其中,参数的含义如下:
input
:输入张量。min
:张量中的最小值。如果为None
,则表示不对最小值进行限制。max
:张量中的最大值。如果为None
,则表示不对最大值进行限制。out
:输出张量。
torch.clamp()
函数返回一个新的张量,其中每个元素都被截断在[min, max]
的范围内。如果min
或max
为None
,则对应的限制条件被忽略。
下面是一个使用torch.clamp()
函数的示例:
import torchx = torch.randn(2, 3)print(x)y = torch.clamp(x, min=-0.5, max=0.5)print(y)
输出结果为:
tensor([[-0.3138, -0.1604, -0.4374],[-1.0861, -0.2837,1.1688]])tensor([[-0.3138, -0.1604, -0.4374],[-0.5000, -0.2837,0.5000]])
可以看到,torch.clamp()
函数将x
张量中的元素限制在了[-0.5, 0.5]
的范围内。