PAI平臺提供圖像多標簽分類相關算法,支持千萬級別超大規模的圖片樣本訓練。本文為您介紹如何使用PAI命令基于圖片數據生成圖像多標簽分類模型。
圖像分類訓練
您可以使用SQL腳本組件進行PAI命令調用,也可以使用MaxCompute客戶端或DataWorks的開發節點進行PAI命令調用。如何使用MaxCompute客戶端和創建DataWorks的開發節點,詳情請參見使用本地客戶端(odpscmd)連接或創建并管理MaxCompute節點。
圖像單標簽分類單機訓練
pai -name easy_vision_ext -Dbuckets='oss://{bucket_name}.{oss_host}/{path}' -Darn='acs:ram::*********:role/aliyunodpspaidefaultrole' -DossHost='{oss_host}' -DgpuRequired=100 -Dcmd train -Dparam_config '--model_type Classification --backbone inception_v4 --num_classes 10 --num_epochs 1 --model_dir oss://examplebucket/test/cifar_inception_v4 --use_pretrained_model true --train_data oss://examplebucket/data/test/cifar10/*.tfrecord --test_data oss://examplebucket/data/test/cifar10/*.tfrecord --num_test_example 20 --train_batch_size 32 --test_batch_size=32 --image_size 299 --initial_learning_rate 0.01 --staircase true'
圖像單標簽分類多機訓練
pai -name easy_vision_ext -Dbuckets='oss://{bucket_name}.{oss_host}/{path}' -Darn='acs:ram::*********:role/aliyunodpspaidefaultrole' -DossHost='{oss_host}' -Dcmd train -Dcluster='{ \"ps\": { \"count\" : 1, \"cpu\" : 600 }, \"worker\" : { \"count\" : 3, \"cpu\" : 800, \"gpu\" : 100 } }' -Dparam_config='--model_type Classification --backbone inception_v4 --num_classes 10 --num_epochs 1 --model_dir oss://examplebucket/test/cifar_inception_v4_dis --use_pretrained_model true --train_data oss://examplebucket/data/test/cifar10/*.tfrecord --test_data oss://examplebucket/data/test/cifar10/*.tfrecord --num_test_example 20 --train_batch_size 32 --test_batch_size=32 --image_size 299 --initial_learning_rate 0.01 --staircase true'
圖像多標簽單機訓練
pai -name easy_vision_ext -Dbuckets='oss://{bucket_name}.{oss_host}/{path}' -Darn='acs:ram::*********:role/aliyunodpspaidefaultrole' -DossHost='{oss_host}' -DgpuRequired=100 -Dcmd train -Dparam_config '--model_type MultiLabelClassification --backbone inception_v4 --num_classes 10 --num_epochs 1 --model_dir oss://examplebucket/test/cifar_inception_v4 --use_pretrained_model true --train_data oss://examplebucket/data/test/cifar10/*.tfrecord --test_data oss://examplebucket/data/test/cifar10/*.tfrecord --num_test_example 20 --train_batch_size 32 --test_batch_size=32 --image_size 299 --initial_learning_rate 0.01 --staircase true'
圖像多標簽多機訓練
pai -name easy_vision_ext -Dbuckets='oss://{bucket_name}.{oss_host}/{path}' -Darn='acs:ram::*********:role/aliyunodpspaidefaultrole' -DossHost='{oss_host}' -Dcmd train -Dcluster='{ \"ps\": { \"count\" : 1, \"cpu\" : 600 }, \"worker\" : { \"count\" : 3, \"cpu\" : 800, \"gpu\" : 100 } }' -Dparam_config='--model_type MultiLabelClassification --backbone inception_v4 --num_classes 10 --num_epochs 1 --model_dir oss://examplebucket/test/cifar_inception_v4_dis --use_pretrained_model true --train_data oss://examplebucket/data/test/cifar10/*.tfrecord --test_data oss://examplebucket/data/test/cifar10/*.tfrecord --num_test_example 20 --train_batch_size 32 --test_batch_size=32 --image_size 299 --initial_learning_rate 0.01 --staircase true'
參數說明
參數 | 是否必選 | 描述 | 取值格式 | 默認值 |
buckets | 是 | OSS Bucket地址。Bucket必須以正斜線(/)結尾。 | oss://{bucket_name}.{oss_host}/{path} | 無 |
arn | 是 | 訪問OSS的授權。您可以登錄PAI控制臺,在全部產品依賴頁面的Designer區域,單擊操作列下的查看授權信息,獲取arn,具體操作請參見云產品依賴與授權:Designer。 | acs:ram::*:role/AliyunODPSPAIDefaultRole | 無 |
ossHost | 否 | OSS訪問域名,詳情請參見訪問域名和數據中心。如果未指定該參數,則從Buckets參數中獲取。 | oss-{region}.aliyuncs.com | 從Buckets參數中獲取 |
cluster | 否 | 分布式訓練參數相關配置。 | JSON格式字符串 | “” |
gpuRequired | 否 | 標識是否使用GPU,默認使用一張卡。如果取值200,則一個Worker申請2張卡。 | 100 | 100 |
cmd | 是 | EasyVision任務類型。模型訓練時,該參數應取值為train。 | train | 無 |
param_config | 是 | 模型訓練參數,其格式與Python Argparser參數格式一致,詳情請參見param_config說明。 | STRING | 無 |
param_config說明
param_config包含若干模型配置相關參數,格式為Python Argparser,示例如下。
-Dparam_config = '--model_type MultiLabelClassification --backbone inception_v4 --num_classes 200 --model_dir oss://your/bucket/exp_dir'
所有字符串類型的參數,其取值均不加引號。
參數名稱 | 是否必選 | 參數描述 | 取值格式 | 默認值 |
model_type | 是 | 訓練模型類型。多標簽分類的模型類型為MultiLabelClassification。 | STRING | 無 |
backbone | 否 | 識別模型的網絡名稱,取值包括:
| STRING | inception_v4 |
num_classes | 是 | 分類類別數量。 | 100 | 無 |
image_size | 否 | 圖片Resize后的大小,單位為像素。 | INT | 224 |
use_crop | 否 | 是否使用crop進行數據增強。 | BOOL | true |
eval_each_category | 否 | 是否針對每個類別單獨進行評估。 | BOOL | false |
optimizer | 否 | 優化方法,取值包括:
| STRING | momentum |
lr_type | 否 | 學習率調整策略,取值包括:
| STRING | exponential_decay |
initial_learning_rate | 否 | 初始學習率。 | FLOAT | 0.01 |
decay_epochs | 否 | 如果使用exponential_decay,該參數對應tf.train.exponential.decay中的decay_steps,系統會自動根據訓練數據總數將decay_epochs轉換為decay_steps。例如,取值為10,通常是總Epoch數的1/2。 如果使用manual_step,該參數表示需要調整學習率的迭代輪數。例如16 18表示在16 Epoch和18 Epoch對學習率進行調整。通常將這兩個值配置為總Epoch的8/10和9/10。 | 整數列表,例如20 20 40 60。 | 20 |
decay_factor | 否 | tf.train.exponential.decay中的decay_factor。 | FLOAT | 0.95 |
staircase | 否 | tf.train.exponential.decay中的staircase。 | BOOL | true |
power | 否 | tf.train.polynomial.decay中的power。 | FLOAT | 0.9 |
learning_rates | 否 | manual_step學習率調整策略中使用的參數,表示在指定Epoch中學習率的取值。 如果您指定的調整Epoch有兩個,則需要在此指定兩個Epoch對應的學習率。例如,如果decay_epochs為20 40,則該將參數配置為0.001 0.0001,表示在20 Epoch學習率調整為0.001,40 Epoch學習率調整為0.0001。建議幾次調整的學習率依次為初始學習率的1/10、1/100及1/1000。 | 浮點列表 | 無 |
train_data | 是 | 訓練數據文件的OSS路徑。 | oss://path/to/train_*.tfrecord | 無 |
test_data | 是 | 訓練過程中,評估數據的OSS路徑。 | oss://path/to/test_*.tfrecord | 無 |
train_batch_size | 是 | 訓練的batch_size。 | INT,例如32。 | 無 |
test_batch_size | 是 | 評估的batch_size。 | INT,例如32。 | 無 |
train_num_readers | 否 | 訓練數據并發讀取線程數。 | INT | 4 |
model_dir | 是 | 訓練的OSS目錄。 | oss://path/to/model | 無 |
pretrained_model | 否 | 預訓練模型OSS路徑。如果指定該路徑,則在該模型基礎上進行微調。 | oss://pai-vision-data-sh/pretrained_models/inception_v4.ckpt | “” |
use_pretrained_model | 否 | 是否使用預訓練模型。 | BOOL | true |
num_epochs | 是 | 訓練迭代次數。取值1表示對所有訓練數據都進行一次迭代。 | INT,例如40。 | 無 |
num_test_example | 否 | 訓練過程中評估數據條目數。取值 -1表示使用所有測試數據作為評估數據。 | INT,例如2000。 | -1 |
num_visualizations | 否 | 評估過程可視化顯示的樣本數量。 | INT | 10 |
save_checkpoint_epochs | 否 | 保存Checkpoint的頻率,以Epoch為單位。取值為1表示每完成一次訓練就保存一次Checkpoint。 | INT | 1 |
num_train_images | 否 | 總的訓練樣本數。如果使用自己生成的TFRecord,則需要指定該參數。 | INT | 0 |
label_map_path | 否 | 類別映射文件。如果使用自己生成的TFRecord,則需要指定該參數。 | STRING | ”” |
相關文檔
與圖像分類模型不同,多標簽分類的多個類別并不互斥,圖像多標簽分類模型會輸出識別概率達到一定閾值的所有類別。您可以將生成的模型部署至EAS,詳情請參見服務部署:控制臺。