feat: closes #24
This commit is contained in:
parent
f68f0e2444
commit
84b9c40a72
@ -110,7 +110,7 @@ func handleRun(handle *Handle) {
|
||||
|
||||
root := tg.NewRoot()
|
||||
|
||||
tf_img := ReadPNG(root, path.Join(run_path, "img.png"), 3)
|
||||
tf_img := ReadPNG(root, path.Join(run_path, "img.png"), int64(model.ImageMode))
|
||||
|
||||
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
||||
inputImage, err:= tf.NewTensor(exec_results[0].Value())
|
||||
@ -130,8 +130,6 @@ func handleRun(handle *Handle) {
|
||||
vi := 0
|
||||
var predictions = results[0].Value().([][]float32)[0]
|
||||
|
||||
fmt.Println(predictions)
|
||||
|
||||
for i, v := range predictions {
|
||||
if v > vmax {
|
||||
vi = i
|
||||
@ -155,7 +153,6 @@ func handleRun(handle *Handle) {
|
||||
var name string
|
||||
if err = rows.Scan(&name); err != nil { return nil }
|
||||
|
||||
|
||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"Result": name,
|
||||
|
@ -19,26 +19,11 @@ import (
|
||||
|
||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||
id = ""
|
||||
_, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
|
||||
if err != nil { return }
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return id, errors.New("Something wrong!")
|
||||
}
|
||||
|
||||
if !rows.Next() { return id, errors.New("Something wrong!") }
|
||||
err = rows.Scan(&id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -103,7 +88,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
||||
return
|
||||
}
|
||||
|
||||
func trainDefinition(handle *Handle, model_id string, definition_id string) (accuracy float64, err error) {
|
||||
func trainDefinition(handle *Handle, model *BaseModel, definition_id string) (accuracy float64, err error) {
|
||||
accuracy = 0
|
||||
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
if err != nil {
|
||||
@ -128,14 +113,14 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
}
|
||||
|
||||
// Generate run folder
|
||||
run_path := path.Join("/tmp", model_id, "defs", definition_id)
|
||||
run_path := path.Join("/tmp", model.Id, "defs", definition_id)
|
||||
|
||||
err = os.MkdirAll(run_path, os.ModePerm)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = generateCvs(handle, run_path, model_id)
|
||||
_, err = generateCvs(handle, run_path, model.Id)
|
||||
if err != nil { return }
|
||||
|
||||
// Create python script
|
||||
@ -153,8 +138,9 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
if err = tmpl.Execute(f, AnyMap{
|
||||
"Layers": got,
|
||||
"Size": got[0].Shape,
|
||||
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
|
||||
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
||||
"RunPath": run_path,
|
||||
"ColorMode": model.ImageMode,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
@ -167,7 +153,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
}
|
||||
|
||||
// Copy result around
|
||||
result_path := path.Join("savedData", model_id, "defs", definition_id)
|
||||
result_path := path.Join("savedData", model.Id, "defs", definition_id)
|
||||
|
||||
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
||||
return
|
||||
@ -235,7 +221,7 @@ func trainModel(handle *Handle, model *BaseModel) {
|
||||
}
|
||||
|
||||
for _, def := range definitions {
|
||||
accuracy, err := trainDefinition(handle, model.Id, def.id)
|
||||
accuracy, err := trainDefinition(handle, model, def.id)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to train definition!Err:\n")
|
||||
fmt.Println(err)
|
||||
|
@ -10,6 +10,7 @@ type BaseModel struct {
|
||||
Status int
|
||||
Id string
|
||||
|
||||
ImageMode int
|
||||
Width int
|
||||
Height int
|
||||
}
|
||||
@ -30,7 +31,7 @@ const (
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id)
|
||||
rows, err := db.Query("select name, status, id, width, height, color_mode from models where id=$1;", id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -41,10 +42,17 @@ func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
}
|
||||
|
||||
base = &BaseModel{}
|
||||
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height)
|
||||
var colorMode string
|
||||
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch colorMode {
|
||||
case "greyscale":
|
||||
base.ImageMode = 1
|
||||
default:
|
||||
panic("unkown color mode")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -32,10 +32,7 @@ def pathToLabel(path):
|
||||
#return tf.strings.as_string([path])
|
||||
|
||||
def decode_image(img):
|
||||
# channels were reduced to 1 since image is grayscale
|
||||
# TODO chnage channel number based if grayscale
|
||||
img = tf.io.decode_png(img, channels=3)
|
||||
|
||||
img = tf.io.decode_png(img, channels={{.ColorMode}})
|
||||
return tf.image.resize(img, image_size)
|
||||
|
||||
def process_path(path):
|
||||
|
Loading…
Reference in New Issue
Block a user