closes #41
This commit is contained in:
		
							parent
							
								
									bc801648a3
								
							
						
					
					
						commit
						ff9aca2699
					
				| @ -79,7 +79,7 @@ func generateCvs(c *Context, run_path string, model_id string) (count int, err e | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id) | ||||
| 	data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @ -74,14 +74,23 @@ def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset: | ||||
|   ds = configure_for_performance(ds) | ||||
|   return ds | ||||
| 
 | ||||
| def filterDataset(path): | ||||
|   path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "") | ||||
|    | ||||
|   path = tf.strings.regex_replace(path, ".jpg", "") | ||||
|   path = tf.strings.regex_replace(path, ".jpeg", "") | ||||
|    | ||||
|   return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1 | ||||
| 
 | ||||
| seed = random.randint(0, 100000000) | ||||
| 
 | ||||
| batch_size = 100 | ||||
| 
 | ||||
| # Read all the files from the direcotry | ||||
| list_ds = tf.data.Dataset.list_files(str(f'{DATA_DIR}/*'), shuffle=False) | ||||
| list_ds = list_ds.filter(filterDataset) | ||||
| 
 | ||||
| image_count = len(list_ds) | ||||
| image_count = len(list(list_ds.as_numpy_iterator())) | ||||
| 
 | ||||
| list_ds = list_ds.shuffle(image_count, seed=seed) | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user