# Train a Keras TabTransformer while logging metrics to Weights & Biases.
#
# NOTE(review): TabTransformer, Adam, nu, X_train, y_train, X_val and y_val
# are not defined in this snippet; they must be provided by the code that
# is inserted at the marked spot below.

from kaggle_secrets import UserSecretsClient  # Kaggle-specific; can be ignored elsewhere
import wandb
from wandb.keras import WandbCallback, WandbMetricsLogger

# --- Authenticate with W&B using a key stored as a Kaggle secret ---
user_secrets = UserSecretsClient()  # Kaggle secrets client
# Kaggle secret; in this case it holds the W&B API key.
secret_value_0 = user_secrets.get_secret("wandb_key")
wandb.login(key=secret_value_0)

# --- Initialize the W&B run ---
run = wandb.init(
    project='open_problems',  # project name; created automatically
    save_code=True,
    name='tabtransformer',
)

#### Insert your own code here (data preparation, model definitions) ####

tabTransformer = TabTransformer(
    categories=nu,  # number of unique elements in each categorical feature
    num_continuous=5,  # number of numerical features
    dim=16,  # embedding/transformer dimension
    dim_out=35,  # dimension of the model output
    depth=6,  # number of transformer layers in the stack
    heads=8,  # number of attention heads
    attn_dropout=0.1,  # attention layer dropout in transformers
    ff_dropout=0.1,  # feed-forward layer dropout in transformers
    mlp_hidden=[(32, 'relu'), (16, 'relu')],  # mlp layer dimensions and activations
)

tabTransformer.compile(Adam(0.001), 'mae', metrics=['mae'])

# WandbMetricsLogger is passed as a Keras callback so training metrics are
# sent to the run created above.
tabTransformer.fit(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    batch_size=32,
    epochs=30,
    callbacks=[WandbMetricsLogger()],
)

run.finish()  # end of the run
# Reference: "[Keras] TabTransformer + W&B" | Kaggle