feat: closes #39
This commit is contained in:
		
							parent
							
								
									c7c6cfcd00
								
							
						
					
					
						commit
						f163e25fba
					
				| @ -92,7 +92,34 @@ func handleEdit(handle *Handle) { | ||||
| 				"Model": model, | ||||
| 			})) | ||||
|         case TRAINING: | ||||
|             fallthrough | ||||
|              | ||||
|             type defrow struct { | ||||
|                 Status int | ||||
|                 EpochProgress int | ||||
|                 Accuracy int | ||||
|             } | ||||
| 
 | ||||
|             def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id) | ||||
|             if err != nil { | ||||
|                 return c.Error500(err) | ||||
|             } | ||||
|             defer def_rows.Close() | ||||
| 
 | ||||
|             defs := []defrow{} | ||||
| 
 | ||||
|             for def_rows.Next() { | ||||
|                 var def defrow | ||||
|                 err = def_rows.Scan(&def.Status, &def.EpochProgress, &def.Accuracy) | ||||
|                 if err != nil { | ||||
|                     return c.Error500(err) | ||||
|                 } | ||||
|                 defs = append(defs, def) | ||||
|             } | ||||
| 
 | ||||
|             LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ | ||||
|                 "Model": model, | ||||
|                 "Defs": defs, | ||||
|             })) | ||||
|         case PREPARING_ZIP_FILE: | ||||
|             LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ | ||||
|                 "Model": model, | ||||
|  | ||||
| @ -160,6 +160,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura | ||||
| 		"RunPath":   run_path, | ||||
| 		"ColorMode": model.ImageMode, | ||||
|         "Model": model, | ||||
|         "DefId": definition_id, | ||||
| 	}); err != nil { | ||||
| 		return | ||||
| 	} | ||||
| @ -239,6 +240,7 @@ func trainModel(c *Context, model *BaseModel) { | ||||
| 	} | ||||
| 
 | ||||
| 	for _, def := range definitions { | ||||
| 		ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING) | ||||
| 		accuracy, err := trainDefinition(c, model, def.id) | ||||
| 		if err != nil { | ||||
| 			c.Logger.Error("Failed to train definition!Err:") | ||||
| @ -480,4 +482,58 @@ func handleTrain(handle *Handle) { | ||||
| 		Redirect("/models/edit?id="+model.Id, c.Mode, w, r) | ||||
| 		return nil | ||||
| 	}) | ||||
| 
 | ||||
| 	handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { | ||||
|         // TODO check auth level
 | ||||
| 		if c.Mode != NORMAL { | ||||
|             // This should only handle normal requests
 | ||||
|             c.Logger.Warn("This function only works with normal") | ||||
|             return c.UnsafeErrorCode(nil, 400, nil) | ||||
| 		} | ||||
| 
 | ||||
|         f := r.URL.Query() | ||||
| 
 | ||||
|         if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") { | ||||
|             c.Logger.Warn("Invalid: model_id or definition or epoch") | ||||
|             return c.UnsafeErrorCode(nil, 400, nil) | ||||
|         } | ||||
| 
 | ||||
| 		model_id := f.Get("model_id") | ||||
| 		def_id := f.Get("definition") | ||||
|         epoch, err := strconv.Atoi(f.Get("epoch")) | ||||
|         if err != nil { | ||||
|             c.Logger.Warn("Epoch is not a number") | ||||
|             // No need to improve message because this function is only called internaly
 | ||||
|             return c.UnsafeErrorCode(nil, 400, nil) | ||||
|         } | ||||
| 
 | ||||
|         rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id) | ||||
|         if err != nil { | ||||
|             return c.Error500(err) | ||||
|         } | ||||
|         defer rows.Close() | ||||
| 
 | ||||
|         if !rows.Next() { | ||||
|             c.Logger.Error("Could not get status of model definition") | ||||
|             return c.Error500(nil) | ||||
|         } | ||||
| 
 | ||||
|         var status int | ||||
|         err = rows.Scan(&status) | ||||
|         if err != nil { | ||||
|             return c.Error500(err) | ||||
|         } | ||||
| 
 | ||||
|         if status != 3 { | ||||
|             c.Logger.Warn("Definition not on status 3(training)", "status", status) | ||||
|             // No need to improve message because this function is only called internaly
 | ||||
|             return c.UnsafeErrorCode(nil, 400, nil) | ||||
|         } | ||||
| 
 | ||||
|         _, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id) | ||||
|         if err != nil { | ||||
|             return c.Error500(err) | ||||
|         } | ||||
| 		return nil | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| @ -387,6 +387,14 @@ func (c Context) ErrorCode(err error, code int, data AnyMap) *Error { | ||||
| 	return &Error{code, nil, c.AddMap(data)} | ||||
| } | ||||
| 
 | ||||
| func (c Context) UnsafeErrorCode(err error, code int, data AnyMap) *Error { | ||||
| 	if err != nil { | ||||
| 		c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code) | ||||
| 		c.Logger.Error(err) | ||||
| 	} | ||||
| 	return &Error{code, nil, c.AddMap(data)} | ||||
| } | ||||
| 
 | ||||
| func (c Context) Error500(err error) *Error { | ||||
| 	return c.ErrorCode(err, http.StatusInternalServerError, nil) | ||||
| } | ||||
| @ -414,6 +422,12 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) | ||||
| 
 | ||||
| 	var token *string | ||||
| 
 | ||||
|     logger := log.NewWithOptions(os.Stdout, log.Options{ | ||||
|         ReportTimestamp: true, | ||||
|         TimeFormat: time.Kitchen, | ||||
|         Prefix: r.URL.Path, | ||||
|     }) | ||||
| 
 | ||||
| 	for _, r := range r.Cookies() { | ||||
| 		if r.Name == "auth" { | ||||
| 			token = &r.Value | ||||
| @ -425,6 +439,8 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) | ||||
| 	if token == nil { | ||||
| 		return &Context{ | ||||
| 			Mode: mode, | ||||
|             Logger: logger, | ||||
|             Db: handler.Db, | ||||
| 		}, nil | ||||
| 	} | ||||
| 
 | ||||
| @ -433,12 +449,6 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) | ||||
| 		return nil, errors.Join(err, LogoffError) | ||||
| 	} | ||||
| 
 | ||||
|     logger := log.NewWithOptions(os.Stdout, log.Options{ | ||||
|         ReportTimestamp: true, | ||||
|         TimeFormat: time.Kitchen, | ||||
|         Prefix: r.URL.Path, | ||||
|     }) | ||||
| 
 | ||||
| 	return &Context{token, user, mode, logger, handler.Db}, nil | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -55,7 +55,8 @@ create table if not exists model_definition ( | ||||
|     -- 4: Tranied | ||||
|     -- 5: Ready | ||||
|     status integer default 1, | ||||
|     created_on timestamp default current_timestamp | ||||
|     created_on timestamp default current_timestamp, | ||||
|     epoch_progress integer default 0 | ||||
| ); | ||||
| 
 | ||||
| -- drop table if exists model_definition_layer; | ||||
|  | ||||
| @ -434,7 +434,20 @@ | ||||
|             {{/* TODO improve this */}} | ||||
|             Training the model...<br/> | ||||
|             {{/* TODO Add progress status on definitions */}} | ||||
|             {{/* TODO Add aility to stop training */}} | ||||
|             {{ range .Defs}} | ||||
|                 <div> | ||||
|                     <div> | ||||
|                         {{.Status}} | ||||
|                     </div> | ||||
|                     <div> | ||||
|                         {{.EpochProgress}} | ||||
|                     </div> | ||||
|                     <div> | ||||
|                         {{.Accuracy}} | ||||
|                     </div> | ||||
|                 </div> | ||||
|             {{ end }} | ||||
|             {{/* TODO Add ability to stop training */}} | ||||
|         </div> | ||||
|       {{/* Model Ready */}} | ||||
|       {{ else if (eq .Model.Status 5)}} | ||||
|  | ||||
| @ -4,6 +4,14 @@ import pandas as pd | ||||
| from tensorflow import keras | ||||
| from tensorflow.data import AUTOTUNE | ||||
| from keras import layers, losses, optimizers | ||||
| import requests | ||||
| 
 | ||||
| class NotifyServerCallback(tf.keras.callbacks.Callback): | ||||
|     def on_epoch_begin(self, epoch, *args, **kwargs): | ||||
|         if (epoch % 5) == 0: | ||||
|             # TODO change this | ||||
|             requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&definition={{.DefId}}') | ||||
| 
 | ||||
| 
 | ||||
| DATA_DIR = "{{ .DataDir }}" | ||||
| image_size = ({{ .Size }}) | ||||
| @ -26,11 +34,15 @@ DATA_DIR_PREPARE = DATA_DIR + "/" | ||||
| #based on https://www.tensorflow.org/tutorials/load_data/images | ||||
| def pathToLabel(path): | ||||
|   path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "") | ||||
|   {{ if eq .Model.Format "png" }} | ||||
|   path = tf.strings.regex_replace(path, ".png", "") | ||||
|   {{ else if eq .Model.Format "jpeg" }} | ||||
|   path = tf.strings.regex_replace(path, ".jpg", "") | ||||
|   path = tf.strings.regex_replace(path, ".jpeg", "") | ||||
|   path = tf.strings.regex_replace(path, ".png", "") | ||||
|   {{ else }} | ||||
|   ERROR | ||||
|   {{ end }} | ||||
|   return table.lookup(tf.strings.as_string([path])) | ||||
|   #return tf.strings.as_string([path]) | ||||
| 
 | ||||
| def decode_image(img): | ||||
|   {{ if eq .Model.Format "png" }} | ||||
| @ -100,7 +112,7 @@ model.compile( | ||||
|     optimizer=tf.keras.optimizers.Adam(), | ||||
|     metrics=['accuracy']) | ||||
| 
 | ||||
| his = model.fit(dataset, validation_data= dataset_validation, epochs=50) | ||||
| his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()]) | ||||
| 
 | ||||
| acc = his.history["accuracy"] | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user