主要的转换语句为 `tensor_dict = {k: tf.constant(v) for k, v in np_example.items() if k in features_metadata}`。在此基础上，增加了特征名称的筛选，以及对不同特征维度和特征数量的判断等。
"""Convert a dict of NumPy protein feature arrays into reshaped TensorFlow tensors.

Every feature listed in FEATURES declares a dtype and a shape. Shape entries
may be string placeholders (NUM_RES, NUM_SEQ, NUM_TEMPLATES) that are resolved
from the data itself before each tensor is reshaped.
"""
from typing import Dict, Mapping, Optional, Sequence, Tuple, Union

import pickle

import numpy as np
import tensorflow as tf

# NOTE(review): the reshape logic below uses tf.control_dependencies, a
# graph-mode (TF1) idiom; under TF2 eager execution the assert ops raise
# immediately instead. A `tensorflow.compat.v1` import was previously used.

# Type aliases.
FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]
FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, tf.Tensor]

# Symbolic dimension names; protein_features_shape() replaces them with
# concrete sizes.
NUM_RES = 'num residues placeholder'
NUM_TEMPLATES = 'num templates placeholder'
NUM_SEQ = "length msa placeholder"

atom_type_num = 37

# Feature name -> (dtype, shape). String entries in a shape are placeholders.
FEATURES = {
    #### Static features of a protein sequence ####
    "aatype": (tf.float32, [NUM_RES, 21]),
    "between_segment_residues": (tf.int64, [NUM_RES, 1]),
    "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
    "domain_name": (tf.string, [1]),
    "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
    "num_alignments": (tf.int64, [NUM_RES, 1]),
    "residue_index": (tf.int64, [NUM_RES, 1]),
    "seq_length": (tf.int64, [NUM_RES, 1]),
    "sequence": (tf.string, [1]),
    "all_atom_positions": (tf.float32, [NUM_RES, atom_type_num, 3]),
    "all_atom_mask": (tf.int64, [NUM_RES, atom_type_num]),
    "resolution": (tf.float32, [1]),
    "template_domain_names": (tf.string, [NUM_TEMPLATES]),
    "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
    "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
    "template_all_atom_positions": (tf.float32, [NUM_TEMPLATES, NUM_RES, atom_type_num, 3]),
    "template_all_atom_masks": (tf.float32, [NUM_TEMPLATES, NUM_RES, atom_type_num, 1]),
}


def _make_features_metadata(feature_names: Sequence[str]) -> FeaturesMetadata:
    """Makes a feature name to type and shape mapping from a list of names.

    "aatype", "sequence" and "seq_length" are always included; in particular
    parse_reshape_logic() reads "seq_length" to resolve NUM_RES.

    Raises:
        KeyError: If a requested name is not a key of FEATURES.
    """
    # Make sure these features are always read.
    required_features = ["aatype", "sequence", "seq_length"]
    wanted = set(feature_names) | set(required_features)
    return {name: FEATURES[name] for name in wanted}


def np_to_tensor_dict(
    np_example: FeatureDict,
    features: Sequence[str],
) -> TensorDict:
    """Creates dict of tensors from a dict of NumPy arrays.

    Args:
        np_example: A dict of NumPy feature arrays.
        features: A list of strings of feature names to be returned in the
            dataset. "aatype", "sequence" and "seq_length" are always added.

    Returns:
        A dictionary of features mapping feature names to features. Only the
        given features are returned, all other ones are filtered out. Each
        tensor is reshaped to the shape declared in FEATURES.

    Raises:
        KeyError: If a requested feature name is not in FEATURES, or if
            np_example has no "seq_length" entry.
    """
    features_metadata = _make_features_metadata(features)
    print(f"features_metadata:{features_metadata}")
    # Keep only the requested features. The dtype is whatever tf.constant
    # infers from each NumPy array; it is not cast to the dtype in FEATURES.
    tensor_dict = {
        k: tf.constant(v) for k, v in np_example.items() if k in features_metadata
    }
    # Ensures shapes are as expected. Needed for setting size of empty features
    # e.g. when no template hits were found.
    return parse_reshape_logic(tensor_dict, features_metadata)


def protein_features_shape(
    feature_name: str,
    num_residues: int,
    msa_length: int,
    num_templates: Optional[int] = None,
    features: Optional[FeaturesMetadata] = None,
) -> list:
    """Get the shape for the given feature name.

    This is near identical to _get_tf_shape_no_placeholders() but with 2
    differences:
    * This method does not calculate a single placeholder from the total number
      of elements (eg given <NUM_RES, 3> and size := 12, this won't deduce
      NUM_RES must be 4).
    * This method will work with tensors.

    Args:
        feature_name: String identifier for the feature. If the feature name
            ends with "_unnormalized", this suffix is stripped off.
        num_residues: The number of residues in the current domain - some
            elements of the shape can be dynamic and will be replaced by this
            value.
        msa_length: The number of sequences in the multiple sequence alignment;
            some elements of the shape can be dynamic and will be replaced by
            this value. If the number of alignments is unknown / not read,
            please pass None for msa_length.
        num_templates (optional): The number of templates in this tfexample.
        features: A feature_name to (tf_dtype, shape) lookup; defaults to
            FEATURES (also used when an empty mapping is passed).

    Returns:
        List of ints (or scalar tensors) representing the tensor size.

    Raises:
        KeyError: If the (suffix-stripped) feature name is not in `features`.
        ValueError: If a feature is requested but no concrete placeholder value
            is given.
    """
    features = features or FEATURES
    suffix = "_unnormalized"
    if feature_name.endswith(suffix):
        feature_name = feature_name[:-len(suffix)]
    # `features` is a FeaturesMetadata mapping: name -> (tf dtype, raw shape).
    unused_dtype, raw_sizes = features[feature_name]

    replacements = {NUM_RES: num_residues, NUM_SEQ: msa_length}
    if num_templates is not None:
        replacements[NUM_TEMPLATES] = num_templates

    # Substitute placeholders; concrete int dimensions pass through unchanged.
    sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
    # Any string still present is a placeholder for which no value was given.
    if any(isinstance(dimension, str) for dimension in sizes):
        raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
            feature_name, raw_sizes, replacements))
    return sizes


def parse_reshape_logic(
    parsed_features: TensorDict,
    features: FeaturesMetadata,
    key: Optional[str] = None,
) -> TensorDict:
    """Transforms parsed serial features to the correct shape.

    Placeholder dimensions are resolved from the data itself: NUM_RES from the
    first element of "seq_length", NUM_SEQ from the first element of
    "num_alignments" (0 if absent) and NUM_TEMPLATES from the leading dimension
    of "template_domain_names" (0 if absent).

    Args:
        parsed_features: Feature name -> tensor. Updated in place and returned.
        features: Feature name -> (dtype, shape) metadata; must contain every
            key of parsed_features.
        key: Optional example key, stored as parsed_features["key"] when "key"
            is in `features`.

    Returns:
        `parsed_features`, with every tensor reshaped to its declared shape.

    Raises:
        KeyError: If "seq_length" is missing from parsed_features, or a key has
            no metadata in `features`.
        ValueError: Propagated from protein_features_shape() for unresolved
            placeholders.
        The TensorFlow assert ops fail if a tensor's element count does not
        match its target shape, or if a non-template feature is empty.
    """
    # Find out what is the number of sequences and the number of alignments.
    num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)
    if "num_alignments" in parsed_features:
        num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32)
    else:
        num_msa = 0

    if "template_domain_names" in parsed_features:
        num_templates = tf.cast(
            tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)
    else:
        num_templates = 0

    if key is not None and "key" in features:
        parsed_features["key"] = [key]  # Expand dims from () to (1,).

    # Reshape the tensors according to the sequence length and num alignments.
    # Only existing keys are reassigned, so updating the dict while iterating
    # over it is safe (its size never changes).
    for k, v in parsed_features.items():
        new_shape = protein_features_shape(
            feature_name=k,
            num_residues=num_residues,
            msa_length=num_msa,
            num_templates=num_templates,
            features=features)
        new_shape_size = tf.constant(1, dtype=tf.int32)
        for dim in new_shape:
            new_shape_size *= tf.cast(dim, tf.int32)

        # Check that the two sizes match; the assert fails otherwise.
        assert_equal = tf.assert_equal(
            tf.size(v), new_shape_size,
            name="assert_%s_shape_correct" % k,
            message="The size of feature %s (%s) could not be reshaped "
                    "into %s" % (k, tf.size(v), new_shape))
        dependencies = [assert_equal]
        if "template" not in k:
            # Make sure the feature we are reshaping is not empty. Template
            # features are exempt: they are empty when no template hits exist.
            assert_non_empty = tf.assert_greater(
                tf.size(v), 0, name="assert_%s_non_empty" % k,
                message="The feature %s is not set in the tf.Example. Either do not "
                        "request the feature or use a tf.Example that has the "
                        "feature set." % k)
            dependencies = [assert_non_empty, assert_equal]
        with tf.control_dependencies(dependencies):
            parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
    return parsed_features


def _first(tensor: tf.Tensor) -> tf.Tensor:
    """Returns the 1st element - the input can be a tensor or a scalar."""
    # Flatten to 1-D first so scalars and N-D inputs are indexed the same way.
    return tf.reshape(tensor, shape=(-1,))[0]


if __name__ == "__main__":
    # Load the list of FeatureDicts.
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load pickles from a trusted source.
    with open("HBB_features_lst.pkl", 'rb') as f:
        HBB_features_lst = pickle.load(f)
    Human_HBB_feature_dict = HBB_features_lst[0]
    print(Human_HBB_feature_dict.keys())

    # Request every known feature; keys absent from FEATURES are dropped.
    features = FEATURES.keys()
    Human_HBB_tensor_dict = np_to_tensor_dict(Human_HBB_feature_dict,
                                              features=features)
    print(Human_HBB_tensor_dict.keys())